Browse Source

🐛 Fix removing body from status codes that do not support it (#5145)

pull/5133/head
Sebastián Ramírez 3 years ago
committed by GitHub
parent
commit
c43120258f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      fastapi/openapi/constants.py
  2. 12
      fastapi/openapi/utils.py
  3. 33
      fastapi/routing.py
  4. 7
      fastapi/utils.py
  5. 9
      tests/test_response_code_no_body.py

1
fastapi/openapi/constants.py

@ -1,3 +1,2 @@
METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"} METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
STATUS_CODES_WITH_NO_BODY = {100, 101, 102, 103, 204, 304}
REF_PREFIX = "#/components/schemas/" REF_PREFIX = "#/components/schemas/"

12
fastapi/openapi/utils.py

@ -9,11 +9,7 @@ from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import ( from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
METHODS_WITH_BODY,
REF_PREFIX,
STATUS_CODES_WITH_NO_BODY,
)
from fastapi.openapi.models import OpenAPI from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param from fastapi.params import Body, Param
from fastapi.responses import Response from fastapi.responses import Response
@ -21,6 +17,7 @@ from fastapi.utils import (
deep_dict_update, deep_dict_update,
generate_operation_id_for_path, generate_operation_id_for_path,
get_model_definitions, get_model_definitions,
is_body_allowed_for_status_code,
) )
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import ModelField, Undefined from pydantic.fields import ModelField, Undefined
@ -265,9 +262,8 @@ def get_openapi_path(
operation.setdefault("responses", {}).setdefault(status_code, {})[ operation.setdefault("responses", {}).setdefault(status_code, {})[
"description" "description"
] = route.response_description ] = route.response_description
if ( if route_response_media_type and is_body_allowed_for_status_code(
route_response_media_type route.status_code
and route.status_code not in STATUS_CODES_WITH_NO_BODY
): ):
response_schema = {"type": "string"} response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse): if lenient_issubclass(current_response_class, JSONResponse):

33
fastapi/routing.py

@ -29,13 +29,13 @@ from fastapi.dependencies.utils import (
) )
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
from fastapi.types import DecoratedCallable from fastapi.types import DecoratedCallable
from fastapi.utils import ( from fastapi.utils import (
create_cloned_field, create_cloned_field,
create_response_field, create_response_field,
generate_unique_id, generate_unique_id,
get_value_or_default, get_value_or_default,
is_body_allowed_for_status_code,
) )
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper, ValidationError from pydantic.error_wrappers import ErrorWrapper, ValidationError
@ -232,7 +232,17 @@ def get_request_handler(
if raw_response.background is None: if raw_response.background is None:
raw_response.background = background_tasks raw_response.background = background_tasks
return raw_response return raw_response
response_data = await serialize_response( response_args: Dict[str, Any] = {"background": background_tasks}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else sub_response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if sub_response.status_code:
response_args["status_code"] = sub_response.status_code
content = await serialize_response(
field=response_field, field=response_field,
response_content=raw_response, response_content=raw_response,
include=response_model_include, include=response_model_include,
@ -243,15 +253,10 @@ def get_request_handler(
exclude_none=response_model_exclude_none, exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine, is_coroutine=is_coroutine,
) )
response_args: Dict[str, Any] = {"background": background_tasks} response = actual_response_class(content, **response_args)
# If status_code was set, use it, otherwise use the default from the if not is_body_allowed_for_status_code(status_code):
# response class, in the case of redirect it's 307 response.body = b""
if status_code is not None:
response_args["status_code"] = status_code
response = actual_response_class(response_data, **response_args)
response.headers.raw.extend(sub_response.headers.raw) response.headers.raw.extend(sub_response.headers.raw)
if sub_response.status_code:
response.status_code = sub_response.status_code
return response return response
return app return app
@ -377,8 +382,8 @@ class APIRoute(routing.Route):
status_code = int(status_code) status_code = int(status_code)
self.status_code = status_code self.status_code = status_code
if self.response_model: if self.response_model:
assert ( assert is_body_allowed_for_status_code(
status_code not in STATUS_CODES_WITH_NO_BODY status_code
), f"Status code {status_code} must not have a response body" ), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id response_name = "Response_" + self.unique_id
self.response_field = create_response_field( self.response_field = create_response_field(
@ -410,8 +415,8 @@ class APIRoute(routing.Route):
assert isinstance(response, dict), "An additional response must be a dict" assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model") model = response.get("model")
if model: if model:
assert ( assert is_body_allowed_for_status_code(
additional_status_code not in STATUS_CODES_WITH_NO_BODY additional_status_code
), f"Status code {additional_status_code} must not have a response body" ), f"Status code {additional_status_code} must not have a response body"
response_name = f"Response_{additional_status_code}_{self.unique_id}" response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_response_field(name=response_name, type_=model) response_field = create_response_field(name=response_name, type_=model)

7
fastapi/utils.py

@ -18,6 +18,13 @@ if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute from .routing import APIRoute
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
if status_code is None:
return True
current_status_code = int(status_code)
return not (current_status_code < 200 or current_status_code in {204, 304})
def get_model_definitions( def get_model_definitions(
*, *,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]], flat_models: Set[Union[Type[BaseModel], Type[Enum]]],

9
tests/test_response_code_no_body.py

@ -28,7 +28,7 @@ class JsonApiError(BaseModel):
responses={500: {"description": "Error", "model": JsonApiError}}, responses={500: {"description": "Error", "model": JsonApiError}},
) )
async def a(): async def a():
pass # pragma: no cover pass
@app.get("/b", responses={204: {"description": "No Content"}}) @app.get("/b", responses={204: {"description": "No Content"}})
@ -106,3 +106,10 @@ def test_openapi_schema():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == openapi_schema assert response.json() == openapi_schema
def test_get_response():
response = client.get("/a")
assert response.status_code == 204, response.text
assert "content-length" not in response.headers
assert response.content == b""

Loading…
Cancel
Save