From c43120258fa89bc20d6f8ee671b6ead9ab223fc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 14 Jul 2022 13:19:42 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20removing=20body=20from=20s?= =?UTF-8?q?tatus=20codes=20that=20do=20not=20support=20it=20(#5145)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/openapi/constants.py | 1 - fastapi/openapi/utils.py | 12 ++++------- fastapi/routing.py | 33 +++++++++++++++++------------ fastapi/utils.py | 7 ++++++ tests/test_response_code_no_body.py | 9 +++++++- 5 files changed, 38 insertions(+), 24 deletions(-) diff --git a/fastapi/openapi/constants.py b/fastapi/openapi/constants.py index 3e69e5524..1897ad750 100644 --- a/fastapi/openapi/constants.py +++ b/fastapi/openapi/constants.py @@ -1,3 +1,2 @@ METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"} -STATUS_CODES_WITH_NO_BODY = {100, 101, 102, 103, 204, 304} REF_PREFIX = "#/components/schemas/" diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 4eb727bd4..5d3d95c24 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -9,11 +9,7 @@ from fastapi.datastructures import DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import get_flat_dependant, get_flat_params from fastapi.encoders import jsonable_encoder -from fastapi.openapi.constants import ( - METHODS_WITH_BODY, - REF_PREFIX, - STATUS_CODES_WITH_NO_BODY, -) +from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX from fastapi.openapi.models import OpenAPI from fastapi.params import Body, Param from fastapi.responses import Response @@ -21,6 +17,7 @@ from fastapi.utils import ( deep_dict_update, generate_operation_id_for_path, get_model_definitions, + is_body_allowed_for_status_code, ) from pydantic import BaseModel from pydantic.fields import ModelField, Undefined @@ -265,9 +262,8 @@ def get_openapi_path( operation.setdefault("responses", {}).setdefault(status_code, {})[ "description" ] = route.response_description - if ( - route_response_media_type - and route.status_code not in STATUS_CODES_WITH_NO_BODY + if route_response_media_type and is_body_allowed_for_status_code( + route.status_code ): response_schema = {"type": "string"} if lenient_issubclass(current_response_class, JSONResponse): diff --git a/fastapi/routing.py b/fastapi/routing.py index a6542c15a..6f1a8e900 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -29,13 +29,13 @@ from fastapi.dependencies.utils import ( ) from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError -from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY from fastapi.types import DecoratedCallable from fastapi.utils import ( create_cloned_field, create_response_field, generate_unique_id, get_value_or_default, + is_body_allowed_for_status_code, ) from pydantic import BaseModel from pydantic.error_wrappers import ErrorWrapper, ValidationError @@ -232,7 +232,17 @@ def get_request_handler( if raw_response.background is None: raw_response.background = background_tasks return raw_response - response_data = await serialize_response( + response_args: Dict[str, Any] = {"background": background_tasks} + # If status_code was set, use it, otherwise use the default from the + # response class, in the case of redirect it's 307 + current_status_code = ( + status_code if status_code else sub_response.status_code + ) + if current_status_code is not None: + response_args["status_code"] = current_status_code + if sub_response.status_code: + response_args["status_code"] = sub_response.status_code + content = await serialize_response( field=response_field, response_content=raw_response, include=response_model_include, @@ -243,15 +253,10 @@ def get_request_handler( exclude_none=response_model_exclude_none, is_coroutine=is_coroutine, ) - response_args: Dict[str, Any] = {"background": background_tasks} - # If status_code was set, use it, otherwise use the default from the - # response class, in the case of redirect it's 307 - if status_code is not None: - response_args["status_code"] = status_code - response = actual_response_class(response_data, **response_args) + response = actual_response_class(content, **response_args) + if not is_body_allowed_for_status_code(status_code): + response.body = b"" response.headers.raw.extend(sub_response.headers.raw) - if sub_response.status_code: - response.status_code = sub_response.status_code return response return app @@ -377,8 +382,8 @@ class APIRoute(routing.Route): status_code = int(status_code) self.status_code = status_code if self.response_model: - assert ( - status_code not in STATUS_CODES_WITH_NO_BODY + assert is_body_allowed_for_status_code( + status_code ), f"Status code {status_code} must not have a response body" response_name = "Response_" + self.unique_id self.response_field = create_response_field( @@ -410,8 +415,8 @@ class APIRoute(routing.Route): assert isinstance(response, dict), "An additional response must be a dict" model = response.get("model") if model: - assert ( - additional_status_code not in STATUS_CODES_WITH_NO_BODY + assert is_body_allowed_for_status_code( + additional_status_code ), f"Status code {additional_status_code} must not have a response body" response_name = f"Response_{additional_status_code}_{self.unique_id}" response_field = create_response_field(name=response_name, type_=model) diff --git a/fastapi/utils.py b/fastapi/utils.py index a7e135bca..887d57c90 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -18,6 +18,13 @@ if TYPE_CHECKING: # pragma: nocover from .routing import APIRoute +def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: + if status_code is None: + return True + current_status_code = int(status_code) + return not (current_status_code < 200 or current_status_code in {204, 304}) + + def get_model_definitions( *, flat_models: Set[Union[Type[BaseModel], Type[Enum]]], diff --git a/tests/test_response_code_no_body.py b/tests/test_response_code_no_body.py index 45e2fabc7..6d9b5c333 100644 --- a/tests/test_response_code_no_body.py +++ b/tests/test_response_code_no_body.py @@ -28,7 +28,7 @@ class JsonApiError(BaseModel): responses={500: {"description": "Error", "model": JsonApiError}}, ) async def a(): - pass # pragma: no cover + pass @app.get("/b", responses={204: {"description": "No Content"}}) @@ -106,3 +106,10 @@ def test_openapi_schema(): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema + + +def test_get_response(): + response = client.get("/a") + assert response.status_code == 204, response.text + assert "content-length" not in response.headers + assert response.content == b""