Browse Source

Add support for specifying a default_response_class (#467)

pull/568/head
toppk 6 years ago
committed by Sebastián Ramírez
parent
commit
f803c77515
  1. 1
      Pipfile
  2. 45
      fastapi/applications.py
  3. 8
      fastapi/openapi/utils.py
  4. 31
      fastapi/routing.py
  5. 1
      pyproject.toml
  6. 216
      tests/test_default_response_class.py
  7. 206
      tests/test_default_response_class_router.py

1
Pipfile

@ -29,6 +29,7 @@ starlette = "==0.12.8"
pydantic = "==0.32.2"
databases = {extras = ["sqlite"],version = "*"}
hypercorn = "*"
orjson = "*"
[requires]
python_version = "3.6"

45
fastapi/applications.py

@ -33,11 +33,13 @@ class FastAPI(Starlette):
version: str = "0.1.0",
openapi_url: Optional[str] = "/openapi.json",
openapi_prefix: str = "",
default_response_class: Type[Response] = JSONResponse,
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
**extra: Dict[str, Any],
) -> None:
self.default_response_class = default_response_class
self._debug = debug
self.router: routing.APIRouter = routing.APIRouter(
routes, dependency_overrides_provider=self
@ -144,7 +146,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> None:
self.router.add_api_route(
@ -166,7 +168,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -190,7 +192,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
@ -213,7 +215,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
return func
@ -240,6 +242,7 @@ class FastAPI(Starlette):
tags: List[str] = None,
dependencies: Sequence[Depends] = None,
responses: Dict[Union[int, str], Dict[str, Any]] = None,
default_response_class: Optional[Type[Response]] = None,
) -> None:
self.router.include_router(
router,
@ -247,6 +250,8 @@ class FastAPI(Starlette):
tags=tags,
dependencies=dependencies,
responses=responses or {},
default_response_class=default_response_class
or self.default_response_class,
)
def get(
@ -268,7 +273,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.get(
@ -288,7 +293,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -311,7 +316,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.put(
@ -331,7 +336,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -354,7 +359,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.post(
@ -374,7 +379,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -397,7 +402,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.delete(
@ -417,7 +422,7 @@ class FastAPI(Starlette):
operation_id=operation_id,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -440,7 +445,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.options(
@ -460,7 +465,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -483,7 +488,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.head(
@ -503,7 +508,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -526,7 +531,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.patch(
@ -546,7 +551,7 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
@ -569,7 +574,7 @@ class FastAPI(Starlette):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.trace(
@ -589,6 +594,6 @@ class FastAPI(Starlette):
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

8
fastapi/openapi/utils.py

@ -151,6 +151,10 @@ def get_openapi_path(
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
assert (
route.response_class and route.response_class.media_type
), "A response class with media_type is needed to generate OpenAPI"
route_response_media_type: str = route.response_class.media_type
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
@ -185,7 +189,7 @@ def get_openapi_path(
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
response.setdefault("content", {}).setdefault(
route.response_class.media_type, {}
route_response_media_type, {}
)["schema"] = response_schema
status_text: Optional[str] = status_code_ranges.get(
str(additional_status_code).upper()
@ -213,7 +217,7 @@ def get_openapi_path(
] = route.response_description
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(route.response_class.media_type, {})[
).setdefault("content", {}).setdefault(route_response_media_type, {})[
"schema"
] = response_schema

31
fastapi/routing.py

@ -200,7 +200,7 @@ class APIRoute(routing.Route):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Optional[Type[Response]] = None,
dependency_overrides_provider: Any = None,
) -> None:
self.path = path
@ -215,9 +215,6 @@ class APIRoute(routing.Route):
)
self.response_model = response_model
if self.response_model:
assert lenient_issubclass(
response_class, JSONResponse
), "To declare a type the response must be a JSON response"
response_name = "Response_" + self.unique_id
self.response_field: Optional[Field] = Field(
name=response_name,
@ -299,7 +296,7 @@ class APIRoute(routing.Route):
dependant=self.dependant,
body_field=self.body_field,
status_code=self.status_code,
response_class=self.response_class,
response_class=self.response_class or JSONResponse,
response_field=self.secure_cloned_response_field,
response_model_include=self.response_model_include,
response_model_exclude=self.response_model_exclude,
@ -346,7 +343,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> None:
route = self.route_class(
@ -394,7 +391,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
@ -445,6 +442,7 @@ class APIRouter(routing.Router):
tags: List[str] = None,
dependencies: Sequence[params.Depends] = None,
responses: Dict[Union[int, str], Dict[str, Any]] = None,
default_response_class: Optional[Type[Response]] = None,
) -> None:
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
@ -484,7 +482,7 @@ class APIRouter(routing.Router):
response_model_by_alias=route.response_model_by_alias,
response_model_skip_defaults=route.response_model_skip_defaults,
include_in_schema=route.include_in_schema,
response_class=route.response_class,
response_class=route.response_class or default_response_class,
name=route.name,
)
elif isinstance(route, routing.Route):
@ -523,10 +521,9 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
path=path,
response_model=response_model,
@ -568,7 +565,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
@ -612,7 +609,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
@ -656,7 +653,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
@ -700,7 +697,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
@ -744,7 +741,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
@ -788,7 +785,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
@ -832,7 +829,7 @@ class APIRouter(routing.Router):
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(

1
pyproject.toml

@ -39,6 +39,7 @@ test = [
"email_validator",
"sqlalchemy",
"databases[sqlite]",
"orjson"
]
doc = [
"mkdocs",

216
tests/test_default_response_class.py

@ -0,0 +1,216 @@
from typing import Any
import orjson
from fastapi import APIRouter, FastAPI
from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse
from starlette.testclient import TestClient
class ORJSONResponse(JSONResponse):
media_type = "application/x-orjson"
def render(self, content: Any) -> bytes:
return orjson.dumps(content)
class OverrideResponse(JSONResponse):
media_type = "application/x-override"
app = FastAPI(default_response_class=ORJSONResponse)
router_a = APIRouter()
router_a_a = APIRouter()
router_a_b_override = APIRouter() # Overrides default class
router_b_override = APIRouter() # Overrides default class
router_b_a = APIRouter()
router_b_a_c_override = APIRouter() # Overrides default class again
@app.get("/")
def get_root():
return {"msg": "Hello World"}
@app.get("/override", response_class=PlainTextResponse)
def get_path_override():
return "Hello World"
@router_a.get("/")
def get_a():
return {"msg": "Hello A"}
@router_a.get("/override", response_class=PlainTextResponse)
def get_a_path_override():
return "Hello A"
@router_a_a.get("/")
def get_a_a():
return {"msg": "Hello A A"}
@router_a_a.get("/override", response_class=PlainTextResponse)
def get_a_a_path_override():
return "Hello A A"
@router_a_b_override.get("/")
def get_a_b():
return "Hello A B"
@router_a_b_override.get("/override", response_class=HTMLResponse)
def get_a_b_path_override():
return "Hello A B"
@router_b_override.get("/")
def get_b():
return "Hello B"
@router_b_override.get("/override", response_class=HTMLResponse)
def get_b_path_override():
return "Hello B"
@router_b_a.get("/")
def get_b_a():
return "Hello B A"
@router_b_a.get("/override", response_class=HTMLResponse)
def get_b_a_path_override():
return "Hello B A"
@router_b_a_c_override.get("/")
def get_b_a_c():
return "Hello B A C"
@router_b_a_c_override.get("/override", response_class=OverrideResponse)
def get_b_a_c_path_override():
return {"msg": "Hello B A C"}
router_b_a.include_router(
router_b_a_c_override, prefix="/c", default_response_class=HTMLResponse
)
router_b_override.include_router(router_b_a, prefix="/a")
router_a.include_router(router_a_a, prefix="/a")
router_a.include_router(
router_a_b_override, prefix="/b", default_response_class=PlainTextResponse
)
app.include_router(router_a, prefix="/a")
app.include_router(
router_b_override, prefix="/b", default_response_class=PlainTextResponse
)
client = TestClient(app)
orjson_type = "application/x-orjson"
text_type = "text/plain; charset=utf-8"
html_type = "text/html; charset=utf-8"
override_type = "application/x-override"
def test_app():
with client:
response = client.get("/")
assert response.json() == {"msg": "Hello World"}
assert response.headers["content-type"] == orjson_type
def test_app_override():
with client:
response = client.get("/override")
assert response.content == b"Hello World"
assert response.headers["content-type"] == text_type
def test_router_a():
with client:
response = client.get("/a")
assert response.json() == {"msg": "Hello A"}
assert response.headers["content-type"] == orjson_type
def test_router_a_override():
with client:
response = client.get("/a/override")
assert response.content == b"Hello A"
assert response.headers["content-type"] == text_type
def test_router_a_a():
with client:
response = client.get("/a/a")
assert response.json() == {"msg": "Hello A A"}
assert response.headers["content-type"] == orjson_type
def test_router_a_a_override():
with client:
response = client.get("/a/a/override")
assert response.content == b"Hello A A"
assert response.headers["content-type"] == text_type
def test_router_a_b():
with client:
response = client.get("/a/b")
assert response.content == b"Hello A B"
assert response.headers["content-type"] == text_type
def test_router_a_b_override():
with client:
response = client.get("/a/b/override")
assert response.content == b"Hello A B"
assert response.headers["content-type"] == html_type
def test_router_b():
with client:
response = client.get("/b")
assert response.content == b"Hello B"
assert response.headers["content-type"] == text_type
def test_router_b_override():
with client:
response = client.get("/b/override")
assert response.content == b"Hello B"
assert response.headers["content-type"] == html_type
def test_router_b_a():
with client:
response = client.get("/b/a")
assert response.content == b"Hello B A"
assert response.headers["content-type"] == text_type
def test_router_b_a_override():
with client:
response = client.get("/b/a/override")
assert response.content == b"Hello B A"
assert response.headers["content-type"] == html_type
def test_router_b_a_c():
with client:
response = client.get("/b/a/c")
assert response.content == b"Hello B A C"
assert response.headers["content-type"] == html_type
def test_router_b_a_c_override():
with client:
response = client.get("/b/a/c/override")
assert response.json() == {"msg": "Hello B A C"}
assert response.headers["content-type"] == override_type

206
tests/test_default_response_class_router.py

@ -0,0 +1,206 @@
from fastapi import APIRouter, FastAPI
from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse
from starlette.testclient import TestClient
class OverrideResponse(JSONResponse):
media_type = "application/x-override"
app = FastAPI()
router_a = APIRouter()
router_a_a = APIRouter()
router_a_b_override = APIRouter() # Overrides default class
router_b_override = APIRouter() # Overrides default class
router_b_a = APIRouter()
router_b_a_c_override = APIRouter() # Overrides default class again
@app.get("/")
def get_root():
return {"msg": "Hello World"}
@app.get("/override", response_class=PlainTextResponse)
def get_path_override():
return "Hello World"
@router_a.get("/")
def get_a():
return {"msg": "Hello A"}
@router_a.get("/override", response_class=PlainTextResponse)
def get_a_path_override():
return "Hello A"
@router_a_a.get("/")
def get_a_a():
return {"msg": "Hello A A"}
@router_a_a.get("/override", response_class=PlainTextResponse)
def get_a_a_path_override():
return "Hello A A"
@router_a_b_override.get("/")
def get_a_b():
return "Hello A B"
@router_a_b_override.get("/override", response_class=HTMLResponse)
def get_a_b_path_override():
return "Hello A B"
@router_b_override.get("/")
def get_b():
return "Hello B"
@router_b_override.get("/override", response_class=HTMLResponse)
def get_b_path_override():
return "Hello B"
@router_b_a.get("/")
def get_b_a():
return "Hello B A"
@router_b_a.get("/override", response_class=HTMLResponse)
def get_b_a_path_override():
return "Hello B A"
@router_b_a_c_override.get("/")
def get_b_a_c():
return "Hello B A C"
@router_b_a_c_override.get("/override", response_class=OverrideResponse)
def get_b_a_c_path_override():
return {"msg": "Hello B A C"}
router_b_a.include_router(
router_b_a_c_override, prefix="/c", default_response_class=HTMLResponse
)
router_b_override.include_router(router_b_a, prefix="/a")
router_a.include_router(router_a_a, prefix="/a")
router_a.include_router(
router_a_b_override, prefix="/b", default_response_class=PlainTextResponse
)
app.include_router(router_a, prefix="/a")
app.include_router(
router_b_override, prefix="/b", default_response_class=PlainTextResponse
)
client = TestClient(app)
json_type = "application/json"
text_type = "text/plain; charset=utf-8"
html_type = "text/html; charset=utf-8"
override_type = "application/x-override"
def test_app():
with client:
response = client.get("/")
assert response.json() == {"msg": "Hello World"}
assert response.headers["content-type"] == json_type
def test_app_override():
with client:
response = client.get("/override")
assert response.content == b"Hello World"
assert response.headers["content-type"] == text_type
def test_router_a():
with client:
response = client.get("/a")
assert response.json() == {"msg": "Hello A"}
assert response.headers["content-type"] == json_type
def test_router_a_override():
with client:
response = client.get("/a/override")
assert response.content == b"Hello A"
assert response.headers["content-type"] == text_type
def test_router_a_a():
with client:
response = client.get("/a/a")
assert response.json() == {"msg": "Hello A A"}
assert response.headers["content-type"] == json_type
def test_router_a_a_override():
with client:
response = client.get("/a/a/override")
assert response.content == b"Hello A A"
assert response.headers["content-type"] == text_type
def test_router_a_b():
with client:
response = client.get("/a/b")
assert response.content == b"Hello A B"
assert response.headers["content-type"] == text_type
def test_router_a_b_override():
with client:
response = client.get("/a/b/override")
assert response.content == b"Hello A B"
assert response.headers["content-type"] == html_type
def test_router_b():
with client:
response = client.get("/b")
assert response.content == b"Hello B"
assert response.headers["content-type"] == text_type
def test_router_b_override():
with client:
response = client.get("/b/override")
assert response.content == b"Hello B"
assert response.headers["content-type"] == html_type
def test_router_b_a():
with client:
response = client.get("/b/a")
assert response.content == b"Hello B A"
assert response.headers["content-type"] == text_type
def test_router_b_a_override():
with client:
response = client.get("/b/a/override")
assert response.content == b"Hello B A"
assert response.headers["content-type"] == html_type
def test_router_b_a_c():
with client:
response = client.get("/b/a/c")
assert response.content == b"Hello B A C"
assert response.headers["content-type"] == html_type
def test_router_b_a_c_override():
with client:
response = client.get("/b/a/c/override")
assert response.json() == {"msg": "Hello B A C"}
assert response.headers["content-type"] == override_type
Loading…
Cancel
Save