From e329d78f866a12893699f786f1209a666e1688e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 29 Sep 2025 12:29:38 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20support=20for=20`Streaming?= =?UTF-8?q?Response`s=20with=20dependencies=20with=20`yield`=20or=20`Uploa?= =?UTF-8?q?dFile`s,=20close=20after=20the=20response=20is=20done=20(#14099?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../en/docs/advanced/advanced-dependencies.md | 88 +++++ .../dependencies/dependencies-with-yield.md | 53 +-- docs_src/dependencies/tutorial013_an_py310.py | 38 ++ docs_src/dependencies/tutorial014_an_py310.py | 39 +++ fastapi/applications.py | 53 ++- fastapi/middleware/asyncexitstack.py | 18 + fastapi/routing.py | 329 +++++++++++------- tests/test_dependency_after_yield_raise.py | 69 ++++ .../test_dependency_after_yield_streaming.py | 130 +++++++ .../test_dependency_after_yield_websockets.py | 79 +++++ tests/test_dependency_contextmanager.py | 11 +- ..._dependency_yield_except_httpexception.py} | 0 tests/test_route_scope.py | 2 +- .../test_dependencies/test_tutorial008c.py | 2 +- 14 files changed, 729 insertions(+), 182 deletions(-) create mode 100644 docs_src/dependencies/tutorial013_an_py310.py create mode 100644 docs_src/dependencies/tutorial014_an_py310.py create mode 100644 fastapi/middleware/asyncexitstack.py create mode 100644 tests/test_dependency_after_yield_raise.py create mode 100644 tests/test_dependency_after_yield_streaming.py create mode 100644 tests/test_dependency_after_yield_websockets.py rename tests/{test_dependency_normal_exceptions.py => test_dependency_yield_except_httpexception.py} (100%) diff --git a/docs/en/docs/advanced/advanced-dependencies.md b/docs/en/docs/advanced/advanced-dependencies.md index c71c11404..e0404b389 100644 --- a/docs/en/docs/advanced/advanced-dependencies.md +++ b/docs/en/docs/advanced/advanced-dependencies.md @@ -63,3 +63,91 @@ In the chapters about security, there are utility functions that are implemented If you understood all this, you already know how those utility tools for security work underneath. /// + +## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks } + +/// warning + +You most probably don't need these technical details. + +These details are useful mainly if you had a FastAPI application older than 0.118.0 and you are facing issues with dependencies with `yield`. + +/// + +Dependencies with `yield` have evolved over time to account for the different use cases and to fix some issues, here's a summary of what has changed. + +### Dependencies with `yield` and `StreamingResponse`, Technical Details { #dependencies-with-yield-and-streamingresponse-technical-details } + +Before FastAPI 0.118.0, if you used a dependency with `yield`, it would run the exit code after the *path operation function* returned but right before sending the response. + +The intention was to avoid holding resources for longer than necessary, waiting for the response to travel through the network. + +This change also meant that if you returned a `StreamingResponse`, the exit code of the dependency with `yield` would have been already run. + +For example, if you had a database session in a dependency with `yield`, the `StreamingResponse` would not be able to use that session while streaming data because the session would have already been closed in the exit code after `yield`. + +This behavior was reverted in 0.118.0, to make the exit code after `yield` be executed after the response is sent. + +/// info + +As you will see below, this is very similar to the behavior before version 0.106.0, but with several improvements and bug fixes for corner cases. + +/// + +#### Use Cases with Early Exit Code { #use-cases-with-early-exit-code } + +There are some use cases with specific conditions that could benefit from the old behavior of running the exit code of dependencies with `yield` before sending the response. + +For example, imagine you have code that uses a database session in a dependency with `yield` only to verify a user, but the database session is never used again in the *path operation function*, only in the dependency, **and** the response takes a long time to be sent, like a `StreamingResponse` that sends data slowly, but for some reason doesn't use the database. + +In this case, the database session would be held until the response is finished being sent, but if you don't use it, then it wouldn't be necessary to hold it. + +Here's how it could look like: + +{* ../../docs_src/dependencies/tutorial013_an_py310.py *} + +The exit code, the automatic closing of the `Session` in: + +{* ../../docs_src/dependencies/tutorial013_an_py310.py ln[19:21] *} + +...would be run after the the response finishes sending the slow data: + +{* ../../docs_src/dependencies/tutorial013_an_py310.py ln[30:38] hl[31:33] *} + +But as `generate_stream()` doesn't use the database session, it is not really necessary to keep the session open while sending the response. + +If you have this specific use case using SQLModel (or SQLAlchemy), you could explicitly close the session after you don't need it anymore: + +{* ../../docs_src/dependencies/tutorial014_an_py310.py ln[24:28] hl[28] *} + +That way the session would release the database connection, so other requests could use it. + +If you have a different use case that needs to exit early from a dependency with `yield`, please create a GitHub Discussion Question with your specific use case and why you would benefit from having early closing for dependencies with `yield`. + +If there are compelling use cases for early closing in dependencies with `yield`, I would consider adding a new way to opt in to early closing. + +### Dependencies with `yield` and `except`, Technical Details { #dependencies-with-yield-and-except-technical-details } + +Before FastAPI 0.110.0, if you used a dependency with `yield`, and then you captured an exception with `except` in that dependency, and you didn't raise the exception again, the exception would be automatically raised/forwarded to any exception handlers or the internal server error handler. + +This was changed in version 0.110.0 to fix unhandled memory consumption from forwarded exceptions without a handler (internal server errors), and to make it consistent with the behavior of regular Python code. + +### Background Tasks and Dependencies with `yield`, Technical Details { #background-tasks-and-dependencies-with-yield-technical-details } + +Before FastAPI 0.106.0, raising exceptions after `yield` was not possible, the exit code in dependencies with `yield` was executed *after* the response was sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} would have already run. + +This was designed this way mainly to allow using the same objects "yielded" by dependencies inside of background tasks, because the exit code would be executed after the background tasks were finished. + +This was changed in FastAPI 0.106.0 with the intention to not hold resources while waiting for the response to travel through the network. + +/// tip + +Additionally, a background task is normally an independent set of logic that should be handled separately, with its own resources (e.g. its own database connection). + +So, this way you will probably have cleaner code. + +/// + +If you used to rely on this behavior, now you should create the resources for background tasks inside the background task itself, and use internally only data that doesn't depend on the resources of dependencies with `yield`. + +For example, instead of using the same database session, you would create a new database session inside of the background task, and you would obtain the objects from the database using this new session. And then instead of passing the object from the database as a parameter to the background task function, you would pass the ID of that object and then obtain the object again inside the background task function. diff --git a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md index 2e2a6a8e3..adc1afa8d 100644 --- a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md +++ b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md @@ -35,7 +35,7 @@ The yielded value is what is injected into *path operations* and other dependenc {* ../../docs_src/dependencies/tutorial007.py hl[4] *} -The code following the `yield` statement is executed after creating the response but before sending it: +The code following the `yield` statement is executed after the response: {* ../../docs_src/dependencies/tutorial007.py hl[5:6] *} @@ -51,7 +51,7 @@ You can use `async` or regular functions. If you use a `try` block in a dependency with `yield`, you'll receive any exception that was thrown when using the dependency. -For example, if some code at some point in the middle, in another dependency or in a *path operation*, made a database transaction "rollback" or create any other error, you will receive the exception in your dependency. +For example, if some code at some point in the middle, in another dependency or in a *path operation*, made a database transaction "rollback" or created any other exception, you would receive the exception in your dependency. So, you can look for that specific exception inside the dependency with `except SomeException`. @@ -95,9 +95,11 @@ This works thanks to Python's ASGIApp: + # Duplicate/override from Starlette to add AsyncExitStackMiddleware + # inside of ExceptionMiddleware, inside of custom user middlewares + debug = self.debug + error_handler = None + exception_handlers: dict[Any, ExceptionHandler] = {} + + for key, value in self.exception_handlers.items(): + if key in (500, Exception): + error_handler = value + else: + exception_handlers[key] = value + + middleware = ( + [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + self.user_middleware + + [ + Middleware( + ExceptionMiddleware, handlers=exception_handlers, debug=debug + ), + # Add FastAPI-specific AsyncExitStackMiddleware for closing files. + # Before this was also used for closing dependencies with yield but + # those now have their own AsyncExitStack, to properly support + # streaming responses while keeping compatibility with the previous + # versions (as of writing 0.117.1) that allowed doing + # except HTTPException inside a dependency with yield. + # This needs to happen after user middlewares because those create a + # new contextvars context copy by using a new AnyIO task group. + # This AsyncExitStack preserves the context for contextvars, not + # strictly necessary for closing files but it was one of the original + # intentions. + # If the AsyncExitStack lived outside of the custom middlewares and + # contextvars were set, for example in a dependency with 'yield' + # in that internal contextvars context, the values would not be + # available in the outer context of the AsyncExitStack. + # By placing the middleware and the AsyncExitStack here, inside all + # user middlewares, the same context is used. + # This is currently not needed, only for closing files, but used to be + # important when dependencies with yield were closed here. + Middleware(AsyncExitStackMiddleware), + ] + ) + + app = self.router + for cls, args, kwargs in reversed(middleware): + app = cls(app, *args, **kwargs) + return app + def openapi(self) -> Dict[str, Any]: """ Generate the OpenAPI schema of the application. This is called by FastAPI diff --git a/fastapi/middleware/asyncexitstack.py b/fastapi/middleware/asyncexitstack.py new file mode 100644 index 000000000..4ce3f5a62 --- /dev/null +++ b/fastapi/middleware/asyncexitstack.py @@ -0,0 +1,18 @@ +from contextlib import AsyncExitStack + +from starlette.types import ASGIApp, Receive, Scope, Send + + +# Used mainly to close files after the request is done, dependencies are closed +# in their own AsyncExitStack +class AsyncExitStackMiddleware: + def __init__( + self, app: ASGIApp, context_name: str = "fastapi_middleware_astack" + ) -> None: + self.app = app + self.context_name = context_name + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async with AsyncExitStack() as stack: + scope[self.context_name] = stack + await self.app(scope, receive, send) diff --git a/fastapi/routing.py b/fastapi/routing.py index f620ced5f..65f739d95 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,5 +1,6 @@ import dataclasses import email.message +import functools import inspect import json import sys @@ -8,6 +9,7 @@ from enum import Enum, IntEnum from typing import ( Any, AsyncIterator, + Awaitable, Callable, Collection, Coroutine, @@ -59,6 +61,8 @@ from fastapi.utils import ( ) from pydantic import BaseModel from starlette import routing +from starlette._exception_handler import wrap_app_handling_exceptions +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request @@ -68,11 +72,9 @@ from starlette.routing import ( Match, compile_path, get_name, - request_response, - websocket_session, ) from starlette.routing import Mount as Mount # noqa -from starlette.types import AppType, ASGIApp, Lifespan, Scope +from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send from starlette.websockets import WebSocket from typing_extensions import Annotated, Doc, deprecated @@ -82,6 +84,73 @@ else: # pragma: no cover from asyncio import iscoroutinefunction +# Copy of starlette.routing.request_response modified to include the +# dependencies' AsyncExitStack +def request_response( + func: Callable[[Request], Union[Awaitable[Response], Response]], +) -> ASGIApp: + """ + Takes a function or coroutine `func(request) -> response`, + and returns an ASGI application. + """ + f: Callable[[Request], Awaitable[Response]] = ( + func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore + ) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope, receive, send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + # Starts customization + response_awaited = False + async with AsyncExitStack() as stack: + scope["fastapi_inner_astack"] = stack + # Same as in Starlette + response = await f(request) + await response(scope, receive, send) + # Continues customization + response_awaited = True + if not response_awaited: + raise FastAPIError( + "Response not awaited. There's a high chance that the " + "application code is raising an exception and a dependency with yield " + "has a block with a bare except, or a block with except Exception, " + "and is not raising the exception again. Read more about it in the " + "docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except" + ) + + # Same as in Starlette + await wrap_app_handling_exceptions(app, request)(scope, receive, send) + + return app + + +# Copy of starlette.routing.websocket_session modified to include the +# dependencies' AsyncExitStack +def websocket_session( + func: Callable[[WebSocket], Awaitable[None]], +) -> ASGIApp: + """ + Takes a coroutine `func(session)`, and returns an ASGI application. + """ + # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + session = WebSocket(scope, receive=receive, send=send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + # Starts customization + async with AsyncExitStack() as stack: + scope["fastapi_inner_astack"] = stack + # Same as in Starlette + await func(session) + + # Same as in Starlette + await wrap_app_handling_exceptions(app, session)(scope, receive, send) + + return app + + def _prepare_response_content( res: Any, *, @@ -246,119 +315,120 @@ def get_request_handler( async def app(request: Request) -> Response: response: Union[Response, None] = None - async with AsyncExitStack() as file_stack: - try: - body: Any = None - if body_field: - if is_body_form: - body = await request.form() - file_stack.push_async_callback(body.close) - else: - body_bytes = await request.body() - if body_bytes: - json_body: Any = Undefined - content_type_value = request.headers.get("content-type") - if not content_type_value: - json_body = await request.json() - else: - message = email.message.Message() - message["content-type"] = content_type_value - if message.get_content_maintype() == "application": - subtype = message.get_content_subtype() - if subtype == "json" or subtype.endswith("+json"): - json_body = await request.json() - if json_body != Undefined: - body = json_body - else: - body = body_bytes - except json.JSONDecodeError as e: - validation_error = RequestValidationError( - [ - { - "type": "json_invalid", - "loc": ("body", e.pos), - "msg": "JSON decode error", - "input": {}, - "ctx": {"error": e.msg}, - } - ], - body=e.doc, - ) - raise validation_error from e - except HTTPException: - # If a middleware raises an HTTPException, it should be raised again - raise - except Exception as e: - http_error = HTTPException( - status_code=400, detail="There was an error parsing the body" - ) - raise http_error from e - errors: List[Any] = [] - async with AsyncExitStack() as async_exit_stack: - solved_result = await solve_dependencies( - request=request, - dependant=dependant, - body=body, - dependency_overrides_provider=dependency_overrides_provider, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, + file_stack = request.scope.get("fastapi_middleware_astack") + assert isinstance(file_stack, AsyncExitStack), ( + "fastapi_middleware_astack not found in request scope" + ) + + # Read body and auto-close files + try: + body: Any = None + if body_field: + if is_body_form: + body = await request.form() + file_stack.push_async_callback(body.close) + else: + body_bytes = await request.body() + if body_bytes: + json_body: Any = Undefined + content_type_value = request.headers.get("content-type") + if not content_type_value: + json_body = await request.json() + else: + message = email.message.Message() + message["content-type"] = content_type_value + if message.get_content_maintype() == "application": + subtype = message.get_content_subtype() + if subtype == "json" or subtype.endswith("+json"): + json_body = await request.json() + if json_body != Undefined: + body = json_body + else: + body = body_bytes + except json.JSONDecodeError as e: + validation_error = RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + } + ], + body=e.doc, + ) + raise validation_error from e + except HTTPException: + # If a middleware raises an HTTPException, it should be raised again + raise + except Exception as e: + http_error = HTTPException( + status_code=400, detail="There was an error parsing the body" + ) + raise http_error from e + + # Solve dependencies and run path operation function, auto-closing dependencies + errors: List[Any] = [] + async_exit_stack = request.scope.get("fastapi_inner_astack") + assert isinstance(async_exit_stack, AsyncExitStack), ( + "fastapi_inner_astack not found in request scope" + ) + solved_result = await solve_dependencies( + request=request, + dependant=dependant, + body=body, + dependency_overrides_provider=dependency_overrides_provider, + async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, + ) + errors = solved_result.errors + if not errors: + raw_response = await run_endpoint_function( + dependant=dependant, + values=solved_result.values, + is_coroutine=is_coroutine, + ) + if isinstance(raw_response, Response): + if raw_response.background is None: + raw_response.background = solved_result.background_tasks + response = raw_response + else: + response_args: Dict[str, Any] = { + "background": solved_result.background_tasks + } + # If status_code was set, use it, otherwise use the default from the + # response class, in the case of redirect it's 307 + current_status_code = ( + status_code if status_code else solved_result.response.status_code ) - errors = solved_result.errors - if not errors: - raw_response = await run_endpoint_function( - dependant=dependant, - values=solved_result.values, - is_coroutine=is_coroutine, - ) - if isinstance(raw_response, Response): - if raw_response.background is None: - raw_response.background = solved_result.background_tasks - response = raw_response - else: - response_args: Dict[str, Any] = { - "background": solved_result.background_tasks - } - # If status_code was set, use it, otherwise use the default from the - # response class, in the case of redirect it's 307 - current_status_code = ( - status_code - if status_code - else solved_result.response.status_code - ) - if current_status_code is not None: - response_args["status_code"] = current_status_code - if solved_result.response.status_code: - response_args["status_code"] = ( - solved_result.response.status_code - ) - content = await serialize_response( - field=response_field, - response_content=raw_response, - include=response_model_include, - exclude=response_model_exclude, - by_alias=response_model_by_alias, - exclude_unset=response_model_exclude_unset, - exclude_defaults=response_model_exclude_defaults, - exclude_none=response_model_exclude_none, - is_coroutine=is_coroutine, - ) - response = actual_response_class(content, **response_args) - if not is_body_allowed_for_status_code(response.status_code): - response.body = b"" - response.headers.raw.extend(solved_result.response.headers.raw) - if errors: - validation_error = RequestValidationError( - _normalize_errors(errors), body=body + if current_status_code is not None: + response_args["status_code"] = current_status_code + if solved_result.response.status_code: + response_args["status_code"] = solved_result.response.status_code + content = await serialize_response( + field=response_field, + response_content=raw_response, + include=response_model_include, + exclude=response_model_exclude, + by_alias=response_model_by_alias, + exclude_unset=response_model_exclude_unset, + exclude_defaults=response_model_exclude_defaults, + exclude_none=response_model_exclude_none, + is_coroutine=is_coroutine, ) - raise validation_error - if response is None: - raise FastAPIError( - "No response object was returned. There's a high chance that the " - "application code is raising an exception and a dependency with yield " - "has a block with a bare except, or a block with except Exception, " - "and is not raising the exception again. Read more about it in the " - "docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except" + response = actual_response_class(content, **response_args) + if not is_body_allowed_for_status_code(response.status_code): + response.body = b"" + response.headers.raw.extend(solved_result.response.headers.raw) + if errors: + validation_error = RequestValidationError( + _normalize_errors(errors), body=body ) + raise validation_error + + # Return response + assert response return response return app @@ -370,24 +440,23 @@ def get_websocket_app( embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: - async with AsyncExitStack() as async_exit_stack: - # TODO: remove this scope later, after a few releases - # This scope fastapi_astack is no longer used by FastAPI, kept for - # compatibility, just in case - websocket.scope["fastapi_astack"] = async_exit_stack - solved_result = await solve_dependencies( - request=websocket, - dependant=dependant, - dependency_overrides_provider=dependency_overrides_provider, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, + async_exit_stack = websocket.scope.get("fastapi_inner_astack") + assert isinstance(async_exit_stack, AsyncExitStack), ( + "fastapi_inner_astack not found in request scope" + ) + solved_result = await solve_dependencies( + request=websocket, + dependant=dependant, + dependency_overrides_provider=dependency_overrides_provider, + async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, + ) + if solved_result.errors: + raise WebSocketRequestValidationError( + _normalize_errors(solved_result.errors) ) - if solved_result.errors: - raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) - ) - assert dependant.call is not None, "dependant.call must be a function" - await dependant.call(**solved_result.values) + assert dependant.call is not None, "dependant.call must be a function" + await dependant.call(**solved_result.values) return app diff --git a/tests/test_dependency_after_yield_raise.py b/tests/test_dependency_after_yield_raise.py new file mode 100644 index 000000000..b560dc36f --- /dev/null +++ b/tests/test_dependency_after_yield_raise.py @@ -0,0 +1,69 @@ +from typing import Any + +import pytest +from fastapi import Depends, FastAPI, HTTPException +from fastapi.testclient import TestClient +from typing_extensions import Annotated + + +class CustomError(Exception): + pass + + +def catching_dep() -> Any: + try: + yield "s" + except CustomError as err: + raise HTTPException(status_code=418, detail="Session error") from err + + +def broken_dep() -> Any: + yield "s" + raise ValueError("Broken after yield") + + +app = FastAPI() + + +@app.get("/catching") +def catching(d: Annotated[str, Depends(catching_dep)]) -> Any: + raise CustomError("Simulated error during streaming") + + +@app.get("/broken") +def broken(d: Annotated[str, Depends(broken_dep)]) -> Any: + return {"message": "all good?"} + + +client = TestClient(app) + + +def test_catching(): + response = client.get("/catching") + assert response.status_code == 418 + assert response.json() == {"detail": "Session error"} + + +def test_broken_raise(): + with pytest.raises(ValueError, match="Broken after yield"): + client.get("/broken") + + +def test_broken_no_raise(): + """ + When a dependency with yield raises after the yield (not in an except), the + response is already "successfully" sent back to the client, but there's still + an error in the server afterwards, an exception is raised and captured or shown + in the server logs. + """ + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/broken") + assert response.status_code == 200 + assert response.json() == {"message": "all good?"} + + +def test_broken_return_finishes(): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/broken") + assert response.status_code == 200 + assert response.json() == {"message": "all good?"} diff --git a/tests/test_dependency_after_yield_streaming.py b/tests/test_dependency_after_yield_streaming.py new file mode 100644 index 000000000..7e1c8822b --- /dev/null +++ b/tests/test_dependency_after_yield_streaming.py @@ -0,0 +1,130 @@ +from contextlib import contextmanager +from typing import Any, Generator + +import pytest +from fastapi import Depends, FastAPI +from fastapi.responses import StreamingResponse +from fastapi.testclient import TestClient +from typing_extensions import Annotated + + +class Session: + def __init__(self) -> None: + self.data = ["foo", "bar", "baz"] + self.open = True + + def __iter__(self) -> Generator[str, None, None]: + for item in self.data: + if self.open: + yield item + else: + raise ValueError("Session closed") + + +@contextmanager +def acquire_session() -> Generator[Session, None, None]: + session = Session() + try: + yield session + finally: + session.open = False + + +def dep_session() -> Any: + with acquire_session() as s: + yield s + + +def broken_dep_session() -> Any: + with acquire_session() as s: + s.open = False + yield s + + +SessionDep = Annotated[Session, Depends(dep_session)] +BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] + +app = FastAPI() + + +@app.get("/data") +def get_data(session: SessionDep) -> Any: + data = list(session) + return data + + +@app.get("/stream-simple") +def get_stream_simple(session: SessionDep) -> Any: + def iter_data(): + yield from ["x", "y", "z"] + + return StreamingResponse(iter_data()) + + +@app.get("/stream-session") +def get_stream_session(session: SessionDep) -> Any: + def iter_data(): + yield from session + + return StreamingResponse(iter_data()) + + +@app.get("/broken-session-data") +def get_broken_session_data(session: BrokenSessionDep) -> Any: + return list(session) + + +@app.get("/broken-session-stream") +def get_broken_session_stream(session: BrokenSessionDep) -> Any: + def iter_data(): + yield from session + + return StreamingResponse(iter_data()) + + +client = TestClient(app) + + +def test_regular_no_stream(): + response = client.get("/data") + assert response.json() == ["foo", "bar", "baz"] + + +def test_stream_simple(): + response = client.get("/stream-simple") + assert response.text == "xyz" + + +def test_stream_session(): + response = client.get("/stream-session") + assert response.text == "foobarbaz" + + +def test_broken_session_data(): + with pytest.raises(ValueError, match="Session closed"): + client.get("/broken-session-data") + + +def test_broken_session_data_no_raise(): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/broken-session-data") + assert response.status_code == 500 + assert response.text == "Internal Server Error" + + +def test_broken_session_stream_raise(): + # Can raise ValueError on Pydantic v2 and ExceptionGroup on Pydantic v1 + with pytest.raises((ValueError, Exception)): + client.get("/broken-session-stream") + + +def test_broken_session_stream_no_raise(): + """ + When a dependency with yield raises after the streaming response already started + the 200 status code is already sent, but there's still an error in the server + afterwards, an exception is raised and captured or shown in the server logs. + """ + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/broken-session-stream") + assert response.status_code == 200 + assert response.text == "" diff --git a/tests/test_dependency_after_yield_websockets.py b/tests/test_dependency_after_yield_websockets.py new file mode 100644 index 000000000..7c323c338 --- /dev/null +++ b/tests/test_dependency_after_yield_websockets.py @@ -0,0 +1,79 @@ +from contextlib import contextmanager +from typing import Any, Generator + +import pytest +from fastapi import Depends, FastAPI, WebSocket +from fastapi.testclient import TestClient +from typing_extensions import Annotated + + +class Session: + def __init__(self) -> None: + self.data = ["foo", "bar", "baz"] + self.open = True + + def __iter__(self) -> Generator[str, None, None]: + for item in self.data: + if self.open: + yield item + else: + raise ValueError("Session closed") + + +@contextmanager +def acquire_session() -> Generator[Session, None, None]: + session = Session() + try: + yield session + finally: + session.open = False + + +def dep_session() -> Any: + with acquire_session() as s: + yield s + + +def broken_dep_session() -> Any: + with acquire_session() as s: + s.open = False + yield s + + +SessionDep = Annotated[Session, Depends(dep_session)] +BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] + +app = FastAPI() + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket, session: SessionDep): + await websocket.accept() + for item in session: + await websocket.send_text(f"{item}") + + +@app.websocket("/ws-broken") +async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep): + await websocket.accept() + for item in session: + await websocket.send_text(f"{item}") # pragma no cover + + +client = TestClient(app) + + +def test_websocket_dependency_after_yield(): + with client.websocket_connect("/ws") as websocket: + data = websocket.receive_text() + assert data == "foo" + data = websocket.receive_text() + assert data == "bar" + data = websocket.receive_text() + assert data == "baz" + + +def test_websocket_dependency_after_yield_broken(): + with pytest.raises(ValueError, match="Session closed"): + with client.websocket_connect("/ws-broken"): + pass # pragma no cover diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py index 039c423b9..02c10458c 100644 --- a/tests/test_dependency_contextmanager.py +++ b/tests/test_dependency_contextmanager.py @@ -286,12 +286,12 @@ def test_background_tasks(): assert data["context_a"] == "started a" assert data["bg"] == "not set" middleware_state = json.loads(response.headers["x-state"]) - assert middleware_state["context_b"] == "finished b with a: started a" - assert middleware_state["context_a"] == "finished a" + assert middleware_state["context_b"] == "started b" + assert middleware_state["context_a"] == "started a" assert middleware_state["bg"] == "not set" assert state["context_b"] == "finished b with a: started a" assert state["context_a"] == "finished a" - assert state["bg"] == "bg set - b: finished b with a: started a - a: finished a" + assert state["bg"] == "bg set - b: started b - a: started a" def test_sync_raise_raises(): @@ -397,7 +397,4 @@ def test_sync_background_tasks(): assert data["sync_bg"] == "not set" assert state["context_b"] == "finished b with a: started a" assert state["context_a"] == "finished a" - assert ( - state["sync_bg"] - == "sync_bg set - b: finished b with a: started a - a: finished a" - ) + assert state["sync_bg"] == "sync_bg set - b: started b - a: started a" diff --git a/tests/test_dependency_normal_exceptions.py b/tests/test_dependency_yield_except_httpexception.py similarity index 100% rename from tests/test_dependency_normal_exceptions.py rename to tests/test_dependency_yield_except_httpexception.py diff --git a/tests/test_route_scope.py b/tests/test_route_scope.py index 2021c828f..792ea66c3 100644 --- a/tests/test_route_scope.py +++ b/tests/test_route_scope.py @@ -47,4 +47,4 @@ def test_websocket(): def test_websocket_invalid_path_doesnt_match(): with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/itemsx/portal-gun"): - pass + pass # pragma: no cover diff --git a/tests/test_tutorial/test_dependencies/test_tutorial008c.py b/tests/test_tutorial/test_dependencies/test_tutorial008c.py index 11e96bf46..369b0a221 100644 --- a/tests/test_tutorial/test_dependencies/test_tutorial008c.py +++ b/tests/test_tutorial/test_dependencies/test_tutorial008c.py @@ -40,7 +40,7 @@ def test_fastapi_error(mod: ModuleType): client = TestClient(mod.app) with pytest.raises(FastAPIError) as exc_info: client.get("/items/portal-gun") - assert "No response object was returned" in exc_info.value.args[0] + assert "raising an exception and a dependency with yield" in exc_info.value.args[0] def test_internal_server_error(mod: ModuleType):