diff --git a/fastapi/routing.py b/fastapi/routing.py index 86e303602..627c03ec2 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -33,6 +33,7 @@ from fastapi._compat import ( from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( + SolvedDependency, _should_embed_body_fields, get_body_field, get_dependant, @@ -240,52 +241,12 @@ def get_request_handler( async def app(request: Request) -> Response: response: Union[Response, None] = None async with AsyncExitStack() as file_stack: - try: - body: Any = None - if body_field: - if is_body_form: - body = await request.form() - file_stack.push_async_callback(body.close) - else: - body_bytes = await request.body() - if body_bytes: - json_body: Any = Undefined - content_type_value = request.headers.get("content-type") - if not content_type_value: - json_body = await request.json() - else: - message = email.message.Message() - message["content-type"] = content_type_value - if message.get_content_maintype() == "application": - subtype = message.get_content_subtype() - if subtype == "json" or subtype.endswith("+json"): - json_body = await request.json() - if json_body != Undefined: - body = json_body - else: - body = body_bytes - except json.JSONDecodeError as e: - validation_error = RequestValidationError( - [ - { - "type": "json_invalid", - "loc": ("body", e.pos), - "msg": "JSON decode error", - "input": {}, - "ctx": {"error": e.msg}, - } - ], - body=e.doc, - ) - raise validation_error from e - except HTTPException: - # If a middleware raises an HTTPException, it should be raised again - raise - except Exception as e: - http_error = HTTPException( - status_code=400, detail="There was an error parsing the body" - ) - raise http_error from e + body = await parse_request_body( + request, + body_field=body_field, + is_body_form=is_body_form, + file_stack=file_stack, + ) errors: List[Any] = [] async with AsyncExitStack() as async_exit_stack: solved_result = await solve_dependencies( @@ -303,42 +264,10 @@ def get_request_handler( values=solved_result.values, is_coroutine=is_coroutine, ) - if isinstance(raw_response, Response): - if raw_response.background is None: - raw_response.background = solved_result.background_tasks - response = raw_response - else: - response_args: Dict[str, Any] = { - "background": solved_result.background_tasks - } - # If status_code was set, use it, otherwise use the default from the - # response class, in the case of redirect it's 307 - current_status_code = ( - status_code - if status_code - else solved_result.response.status_code - ) - if current_status_code is not None: - response_args["status_code"] = current_status_code - if solved_result.response.status_code: - response_args["status_code"] = ( - solved_result.response.status_code - ) - content = await serialize_response( - field=response_field, - response_content=raw_response, - include=response_model_include, - exclude=response_model_exclude, - by_alias=response_model_by_alias, - exclude_unset=response_model_exclude_unset, - exclude_defaults=response_model_exclude_defaults, - exclude_none=response_model_exclude_none, - is_coroutine=is_coroutine, - ) - response = actual_response_class(content, **response_args) - if not is_body_allowed_for_status_code(response.status_code): - response.body = b"" - response.headers.raw.extend(solved_result.response.headers.raw) + response = await prepare_response_object( + raw_response, + solved_result, + ) if errors: validation_error = RequestValidationError( _normalize_errors(errors), body=body @@ -354,9 +283,103 @@ def get_request_handler( ) return response + async def prepare_response_object( + raw_response: Any, + solved_result: SolvedDependency, + ) -> Response: + if isinstance(raw_response, Response): + if raw_response.background is None: + raw_response.background = solved_result.background_tasks + response = raw_response + else: + response_args: Dict[str, Any] = { + "background": solved_result.background_tasks + } + # If status_code was set, use it, otherwise use the default from the + # response class, in the case of redirect it's 307 + current_status_code = ( + status_code if status_code else solved_result.response.status_code + ) + if current_status_code is not None: + response_args["status_code"] = current_status_code + if solved_result.response.status_code: + response_args["status_code"] = solved_result.response.status_code + content = await serialize_response( + field=response_field, + response_content=raw_response, + include=response_model_include, + exclude=response_model_exclude, + by_alias=response_model_by_alias, + exclude_unset=response_model_exclude_unset, + exclude_defaults=response_model_exclude_defaults, + exclude_none=response_model_exclude_none, + is_coroutine=is_coroutine, + ) + response = actual_response_class(content, **response_args) + if not is_body_allowed_for_status_code(response.status_code): + response.body = b"" + response.headers.raw.extend(solved_result.response.headers.raw) + return response + return app +async def parse_request_body( + request: Request, + *, + body_field: Optional[ModelField], + is_body_form: bool | None, + file_stack: AsyncExitStack, +) -> Any: + try: + body: Any = None + if body_field: + if is_body_form: + body = await request.form() + file_stack.push_async_callback(body.close) + else: + body_bytes = await request.body() + if body_bytes: + json_body: Any = Undefined + content_type_value = request.headers.get("content-type") + if not content_type_value: + json_body = await request.json() + else: + message = email.message.Message() + message["content-type"] = content_type_value + if message.get_content_maintype() == "application": + subtype = message.get_content_subtype() + if subtype == "json" or subtype.endswith("+json"): + json_body = await request.json() + if json_body != Undefined: + body = json_body + else: + body = body_bytes + except json.JSONDecodeError as e: + validation_error = RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + } + ], + body=e.doc, + ) + raise validation_error from e + except HTTPException: + # If a middleware raises an HTTPException, it should be raised again + raise + except Exception as e: + http_error = HTTPException( + status_code=400, detail="There was an error parsing the body" + ) + raise http_error from e + return body + + def get_websocket_app( dependant: Dependant, dependency_overrides_provider: Optional[Any] = None,