diff --git a/fastapi/routing.py b/fastapi/routing.py index f620ced5f..765d1cfea 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,5 +1,6 @@ import dataclasses import email.message +import functools import inspect import json import sys @@ -8,6 +9,7 @@ from enum import Enum, IntEnum from typing import ( Any, AsyncIterator, + Awaitable, Callable, Collection, Coroutine, @@ -59,6 +61,8 @@ from fastapi.utils import ( ) from pydantic import BaseModel from starlette import routing +from starlette._exception_handler import wrap_app_handling_exceptions +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request @@ -68,11 +72,9 @@ from starlette.routing import ( Match, compile_path, get_name, - request_response, - websocket_session, ) from starlette.routing import Mount as Mount # noqa -from starlette.types import AppType, ASGIApp, Lifespan, Scope +from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send from starlette.websockets import WebSocket from typing_extensions import Annotated, Doc, deprecated @@ -246,111 +248,119 @@ def get_request_handler( async def app(request: Request) -> Response: response: Union[Response, None] = None - async with AsyncExitStack() as file_stack: - try: - body: Any = None - if body_field: - if is_body_form: - body = await request.form() - file_stack.push_async_callback(body.close) - else: - body_bytes = await request.body() - if body_bytes: - json_body: Any = Undefined - content_type_value = request.headers.get("content-type") - if not content_type_value: - json_body = await request.json() - else: - message = email.message.Message() - message["content-type"] = content_type_value - if message.get_content_maintype() == "application": - subtype = message.get_content_subtype() - if subtype == "json" or subtype.endswith("+json"): - json_body = await request.json() - if json_body != Undefined: - body = json_body - else: - body = body_bytes - except json.JSONDecodeError as e: - validation_error = RequestValidationError( - [ - { - "type": "json_invalid", - "loc": ("body", e.pos), - "msg": "JSON decode error", - "input": {}, - "ctx": {"error": e.msg}, - } - ], - body=e.doc, - ) - raise validation_error from e - except HTTPException: - # If a middleware raises an HTTPException, it should be raised again - raise - except Exception as e: - http_error = HTTPException( - status_code=400, detail="There was an error parsing the body" + file_stack = request.scope.get("fastapi_middleware_astack") + assert isinstance( + file_stack, AsyncExitStack + ), "fastapi_astack not found in request scope" + + # Read body and auto-close files + try: + body: Any = None + if body_field: + if is_body_form: + body = await request.form() + file_stack.push_async_callback(body.close) + else: + body_bytes = await request.body() + if body_bytes: + json_body: Any = Undefined + content_type_value = request.headers.get("content-type") + if not content_type_value: + json_body = await request.json() + else: + message = email.message.Message() + message["content-type"] = content_type_value + if message.get_content_maintype() == "application": + subtype = message.get_content_subtype() + if subtype == "json" or subtype.endswith("+json"): + json_body = await request.json() + if json_body != Undefined: + body = json_body + else: + body = body_bytes + except json.JSONDecodeError as e: + validation_error = RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + } + ], + body=e.doc, + ) + raise validation_error from e + except HTTPException: + # If a middleware raises an HTTPException, it should be raised again + raise + except Exception as e: + http_error = HTTPException( + status_code=400, detail="There was an error parsing the body" + ) + raise http_error from e + + # Solve dependencies and run path operation function, auto-closing dependencies + errors: List[Any] = [] + async_exit_stack = request.scope.get("fastapi_inner_astack") + assert isinstance( + async_exit_stack, AsyncExitStack + ), "fastapi_inner_astack not found in request scope" + solved_result = await solve_dependencies( + request=request, + dependant=dependant, + body=body, + dependency_overrides_provider=dependency_overrides_provider, + async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, + ) + errors = solved_result.errors + if not errors: + raw_response = await run_endpoint_function( + dependant=dependant, + values=solved_result.values, + is_coroutine=is_coroutine, + ) + if isinstance(raw_response, Response): + if raw_response.background is None: + raw_response.background = solved_result.background_tasks + response = raw_response + else: + response_args: Dict[str, Any] = { + "background": solved_result.background_tasks + } + # If status_code was set, use it, otherwise use the default from the + # response class, in the case of redirect it's 307 + current_status_code = ( + status_code if status_code else solved_result.response.status_code ) - raise http_error from e - errors: List[Any] = [] - async with AsyncExitStack() as async_exit_stack: - solved_result = await solve_dependencies( - request=request, - dependant=dependant, - body=body, - dependency_overrides_provider=dependency_overrides_provider, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, + if current_status_code is not None: + response_args["status_code"] = current_status_code + if solved_result.response.status_code: + response_args["status_code"] = solved_result.response.status_code + content = await serialize_response( + field=response_field, + response_content=raw_response, + include=response_model_include, + exclude=response_model_exclude, + by_alias=response_model_by_alias, + exclude_unset=response_model_exclude_unset, + exclude_defaults=response_model_exclude_defaults, + exclude_none=response_model_exclude_none, + is_coroutine=is_coroutine, ) - errors = solved_result.errors - if not errors: - raw_response = await run_endpoint_function( - dependant=dependant, - values=solved_result.values, - is_coroutine=is_coroutine, - ) - if isinstance(raw_response, Response): - if raw_response.background is None: - raw_response.background = solved_result.background_tasks - response = raw_response - else: - response_args: Dict[str, Any] = { - "background": solved_result.background_tasks - } - # If status_code was set, use it, otherwise use the default from the - # response class, in the case of redirect it's 307 - current_status_code = ( - status_code - if status_code - else solved_result.response.status_code - ) - if current_status_code is not None: - response_args["status_code"] = current_status_code - if solved_result.response.status_code: - response_args["status_code"] = ( - solved_result.response.status_code - ) - content = await serialize_response( - field=response_field, - response_content=raw_response, - include=response_model_include, - exclude=response_model_exclude, - by_alias=response_model_by_alias, - exclude_unset=response_model_exclude_unset, - exclude_defaults=response_model_exclude_defaults, - exclude_none=response_model_exclude_none, - is_coroutine=is_coroutine, - ) - response = actual_response_class(content, **response_args) - if not is_body_allowed_for_status_code(response.status_code): - response.body = b"" - response.headers.raw.extend(solved_result.response.headers.raw) - if errors: - validation_error = RequestValidationError( - _normalize_errors(errors), body=body - ) - raise validation_error + response = actual_response_class(content, **response_args) + if not is_body_allowed_for_status_code(response.status_code): + response.body = b"" + response.headers.raw.extend(solved_result.response.headers.raw) + if errors: + validation_error = RequestValidationError( + _normalize_errors(errors), body=body + ) + raise validation_error + + # Return response if response is None: raise FastAPIError( "No response object was returned. There's a high chance that the " @@ -370,24 +380,23 @@ def get_websocket_app( embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: - async with AsyncExitStack() as async_exit_stack: - # TODO: remove this scope later, after a few releases - # This scope fastapi_astack is no longer used by FastAPI, kept for - # compatibility, just in case - websocket.scope["fastapi_astack"] = async_exit_stack - solved_result = await solve_dependencies( - request=websocket, - dependant=dependant, - dependency_overrides_provider=dependency_overrides_provider, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, + async_exit_stack = websocket.scope.get("fastapi_inner_astack") + assert isinstance( + async_exit_stack, AsyncExitStack + ), "fastapi_inner_astack not found in request scope" + solved_result = await solve_dependencies( + request=websocket, + dependant=dependant, + dependency_overrides_provider=dependency_overrides_provider, + async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, + ) + if solved_result.errors: + raise WebSocketRequestValidationError( + _normalize_errors(solved_result.errors) ) - if solved_result.errors: - raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) - ) - assert dependant.call is not None, "dependant.call must be a function" - await dependant.call(**solved_result.values) + assert dependant.call is not None, "dependant.call must be a function" + await dependant.call(**solved_result.values) return app diff --git a/tests/test_tutorial/test_dependencies/test_tutorial008c.py b/tests/test_tutorial/test_dependencies/test_tutorial008c.py index 11e96bf46..369b0a221 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial008c.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial008c.py @@ -40,7 +40,7 @@ def test_fastapi_error(mod: ModuleType): client = TestClient(mod.app) with pytest.raises(FastAPIError) as exc_info: client.get("/items/portal-gun") - assert "No response object was returned" in exc_info.value.args[0] + assert "raising an exception and a dependency with yield" in exc_info.value.args[0] def test_internal_server_error(mod: ModuleType):