From 9d56a3cb59d59896bc38293b9fa54ae69b7cd36c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 17 Feb 2022 13:40:12 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Update=20internal=20`AsyncExitStack?= =?UTF-8?q?`=20to=20fix=20context=20for=20dependencies=20with=20`yield`=20?= =?UTF-8?q?(#4575)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dependencies/dependencies-with-yield.md | 10 +-- fastapi/applications.py | 61 +++++++++++++--- fastapi/middleware/asyncexitstack.py | 28 ++++++++ tests/test_dependency_contextmanager.py | 44 ++++++++++-- tests/test_dependency_contextvars.py | 51 +++++++++++++ tests/test_dependency_normal_exceptions.py | 71 +++++++++++++++++++ tests/test_exception_handlers.py | 23 ++++++ 7 files changed, 272 insertions(+), 16 deletions(-) create mode 100644 fastapi/middleware/asyncexitstack.py create mode 100644 tests/test_dependency_contextvars.py create mode 100644 tests/test_dependency_normal_exceptions.py diff --git a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md index 82553afae..ac2e9cb8c 100644 --- a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md +++ b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md @@ -99,7 +99,7 @@ You saw that you can use dependencies with `yield` and have `try` blocks that ca It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**. -The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`). +The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`). So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore. @@ -138,9 +138,11 @@ participant tasks as Background tasks end dep ->> operation: Run dependency, e.g. DB session opt raise - operation -->> handler: Raise HTTPException + operation -->> dep: Raise HTTPException + dep -->> handler: Auto forward exception handler -->> client: HTTP error response operation -->> dep: Raise other exception + dep -->> handler: Auto forward exception end operation ->> client: Return response to client Note over client,operation: Response is already sent, can't change it anymore @@ -162,9 +164,9 @@ participant tasks as Background tasks After one of those responses is sent, no other response can be sent. !!! tip - This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code. + This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. - But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency. + If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server. ## Context Managers diff --git a/fastapi/applications.py b/fastapi/applications.py index dbfd76fb9..9fb78719c 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union from fastapi import routing -from fastapi.concurrency import AsyncExitStack from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.encoders import DictIntStrAny, SetIntStr from fastapi.exception_handlers import ( @@ -11,6 +10,7 @@ from fastapi.exception_handlers import ( ) from fastapi.exceptions import RequestValidationError from fastapi.logger import logger +from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware from fastapi.openapi.docs import ( get_redoc_html, get_swagger_ui_html, @@ -21,8 +21,9 @@ from fastapi.params import Depends from fastapi.types import DecoratedCallable from starlette.applications import Starlette from starlette.datastructures import State -from starlette.exceptions import HTTPException +from starlette.exceptions import ExceptionMiddleware, HTTPException from starlette.middleware import Middleware +from starlette.middleware.errors import ServerErrorMiddleware from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.routing import BaseRoute @@ -134,6 +135,55 @@ class FastAPI(Starlette): self.openapi_schema: Optional[Dict[str, Any]] = None self.setup() + def build_middleware_stack(self) -> ASGIApp: + # Duplicate/override from Starlette to add AsyncExitStackMiddleware + # inside of ExceptionMiddleware, inside of custom user middlewares + debug = self.debug + error_handler = None + exception_handlers = {} + + for key, value in self.exception_handlers.items(): + if key in (500, Exception): + error_handler = value + else: + exception_handlers[key] = value + + middleware = ( + [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + self.user_middleware + + [ + Middleware( + ExceptionMiddleware, handlers=exception_handlers, debug=debug + ), + # Add FastAPI-specific AsyncExitStackMiddleware for dependencies with + # contextvars. + # This needs to happen after user middlewares because those create a + # new contextvars context copy by using a new AnyIO task group. + # The initial part of dependencies with yield is executed in the + # FastAPI code, inside all the middlewares, but the teardown part + # (after yield) is executed in the AsyncExitStack in this middleware, + # if the AsyncExitStack lived outside of the custom middlewares and + # contextvars were set in a dependency with yield in that internal + # contextvars context, the values would not be available in the + # outside context of the AsyncExitStack. + # By putting the middleware and the AsyncExitStack here, inside all + # user middlewares, the code before and after yield in dependencies + # with yield is executed in the same contextvars context, so all values + # set in contextvars before yield is still available after yield as + # would be expected. + # Additionally, by having this AsyncExitStack here, after the + # ExceptionMiddleware, now dependencies can catch handled exceptions, + # e.g. HTTPException, to customize the teardown code (e.g. DB session + # rollback). + Middleware(AsyncExitStackMiddleware), + ] + ) + + app = self.router + for cls, options in reversed(middleware): + app = cls(app=app, **options) + return app + def openapi(self) -> Dict[str, Any]: if not self.openapi_schema: self.openapi_schema = get_openapi( @@ -206,12 +256,7 @@ class FastAPI(Starlette): async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.root_path: scope["root_path"] = self.root_path - if AsyncExitStack: - async with AsyncExitStack() as stack: - scope["fastapi_astack"] = stack - await super().__call__(scope, receive, send) - else: - await super().__call__(scope, receive, send) # pragma: no cover + await super().__call__(scope, receive, send) def add_api_route( self, diff --git a/fastapi/middleware/asyncexitstack.py b/fastapi/middleware/asyncexitstack.py new file mode 100644 index 000000000..503a68ac7 --- /dev/null +++ b/fastapi/middleware/asyncexitstack.py @@ -0,0 +1,28 @@ +from typing import Optional + +from fastapi.concurrency import AsyncExitStack +from starlette.types import ASGIApp, Receive, Scope, Send + + +class AsyncExitStackMiddleware: + def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None: + self.app = app + self.context_name = context_name + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if AsyncExitStack: + dependency_exception: Optional[Exception] = None + async with AsyncExitStack() as stack: + scope[self.context_name] = stack + try: + await self.app(scope, receive, send) + except Exception as e: + dependency_exception = e + raise e + if dependency_exception: + # This exception was possibly handled by the dependency but it should + # still bubble up so that the ServerErrorMiddleware can return a 500 + # or the ExceptionMiddleware can catch and handle any other exceptions + raise dependency_exception + else: + await self.app(scope, receive, send) # pragma: no cover diff --git a/tests/test_dependency_contextmanager.py b/tests/test_dependency_contextmanager.py index 3e42b47f7..03ef56c4d 100644 --- a/tests/test_dependency_contextmanager.py +++ b/tests/test_dependency_contextmanager.py @@ -235,7 +235,16 @@ def test_sync_raise_other(): assert "/sync_raise" not in errors -def test_async_raise(): +def test_async_raise_raises(): + with pytest.raises(AsyncDependencyError): + client.get("/async_raise") + assert state["/async_raise"] == "asyncgen raise finalized" + assert "/async_raise" in errors + errors.clear() + + +def test_async_raise_server_error(): + client = TestClient(app, raise_server_exceptions=False) response = client.get("/async_raise") assert response.status_code == 500, response.text assert state["/async_raise"] == "asyncgen raise finalized" @@ -270,7 +279,16 @@ def test_background_tasks(): assert state["bg"] == "bg set - b: started b - a: started a" -def test_sync_raise(): +def test_sync_raise_raises(): + with pytest.raises(SyncDependencyError): + client.get("/sync_raise") + assert state["/sync_raise"] == "generator raise finalized" + assert "/sync_raise" in errors + errors.clear() + + +def test_sync_raise_server_error(): + client = TestClient(app, raise_server_exceptions=False) response = client.get("/sync_raise") assert response.status_code == 500, response.text assert state["/sync_raise"] == "generator raise finalized" @@ -306,7 +324,16 @@ def test_sync_sync_raise_other(): assert "/sync_raise" not in errors -def test_sync_async_raise(): +def test_sync_async_raise_raises(): + with pytest.raises(AsyncDependencyError): + client.get("/sync_async_raise") + assert state["/async_raise"] == "asyncgen raise finalized" + assert "/async_raise" in errors + errors.clear() + + +def test_sync_async_raise_server_error(): + client = TestClient(app, raise_server_exceptions=False) response = client.get("/sync_async_raise") assert response.status_code == 500, response.text assert state["/async_raise"] == "asyncgen raise finalized" @@ -314,7 +341,16 @@ def test_sync_async_raise(): errors.clear() -def test_sync_sync_raise(): +def test_sync_sync_raise_raises(): + with pytest.raises(SyncDependencyError): + client.get("/sync_sync_raise") + assert state["/sync_raise"] == "generator raise finalized" + assert "/sync_raise" in errors + errors.clear() + + +def test_sync_sync_raise_server_error(): + client = TestClient(app, raise_server_exceptions=False) response = client.get("/sync_sync_raise") assert response.status_code == 500, response.text assert state["/sync_raise"] == "generator raise finalized" diff --git a/tests/test_dependency_contextvars.py b/tests/test_dependency_contextvars.py new file mode 100644 index 000000000..076802df8 --- /dev/null +++ b/tests/test_dependency_contextvars.py @@ -0,0 +1,51 @@ +from contextvars import ContextVar +from typing import Any, Awaitable, Callable, Dict, Optional + +from fastapi import Depends, FastAPI, Request, Response +from fastapi.testclient import TestClient + +legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar( + "legacy_request_state_context_var", default=None +) + +app = FastAPI() + + +async def set_up_request_state_dependency(): + request_state = {"user": "deadpond"} + contextvar_token = legacy_request_state_context_var.set(request_state) + yield request_state + legacy_request_state_context_var.reset(contextvar_token) + + +@app.middleware("http") +async def custom_middleware( + request: Request, call_next: Callable[[Request], Awaitable[Response]] +): + response = await call_next(request) + response.headers["custom"] = "foo" + return response + + +@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)]) +def get_user(): + request_state = legacy_request_state_context_var.get() + assert request_state + return request_state["user"] + + +client = TestClient(app) + + +def test_dependency_contextvars(): + """ + Check that custom middlewares don't affect the contextvar context for dependencies. + + The code before yield and the code after yield should be run in the same contextvar + context, so that request_state_context_var.reset(contextvar_token). + + If they are run in a different context, that raises an error. + """ + response = client.get("/user") + assert response.json() == "deadpond" + assert response.headers["custom"] == "foo" diff --git a/tests/test_dependency_normal_exceptions.py b/tests/test_dependency_normal_exceptions.py new file mode 100644 index 000000000..49a19f460 --- /dev/null +++ b/tests/test_dependency_normal_exceptions.py @@ -0,0 +1,71 @@ +import pytest +from fastapi import Body, Depends, FastAPI, HTTPException +from fastapi.testclient import TestClient + +initial_fake_database = {"rick": "Rick Sanchez"} + +fake_database = initial_fake_database.copy() + +initial_state = {"except": False, "finally": False} + +state = initial_state.copy() + +app = FastAPI() + + +async def get_database(): + temp_database = fake_database.copy() + try: + yield temp_database + fake_database.update(temp_database) + except HTTPException: + state["except"] = True + finally: + state["finally"] = True + + +@app.put("/invalid-user/{user_id}") +def put_invalid_user( + user_id: str, name: str = Body(...), db: dict = Depends(get_database) +): + db[user_id] = name + raise HTTPException(status_code=400, detail="Invalid user") + + +@app.put("/user/{user_id}") +def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)): + db[user_id] = name + return {"message": "OK"} + + +@pytest.fixture(autouse=True) +def reset_state_and_db(): + global fake_database + global state + fake_database = initial_fake_database.copy() + state = initial_state.copy() + + +client = TestClient(app) + + +def test_dependency_gets_exception(): + assert state["except"] is False + assert state["finally"] is False + response = client.put("/invalid-user/rick", json="Morty") + assert response.status_code == 400, response.text + assert response.json() == {"detail": "Invalid user"} + assert state["except"] is True + assert state["finally"] is True + assert fake_database["rick"] == "Rick Sanchez" + + +def test_dependency_no_exception(): + assert state["except"] is False + assert state["finally"] is False + response = client.put("/user/rick", json="Morty") + assert response.status_code == 200, response.text + assert response.json() == {"message": "OK"} + assert state["except"] is False + assert state["finally"] is True + assert fake_database["rick"] == "Morty" diff --git a/tests/test_exception_handlers.py b/tests/test_exception_handlers.py index 6153f7ab9..67a4becec 100644 --- a/tests/test_exception_handlers.py +++ b/tests/test_exception_handlers.py @@ -1,3 +1,4 @@ +import pytest from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.testclient import TestClient @@ -12,10 +13,15 @@ def request_validation_exception_handler(request, exception): return JSONResponse({"exception": "request-validation"}) +def server_error_exception_handler(request, exception): + return JSONResponse(status_code=500, content={"exception": "server-error"}) + + app = FastAPI( exception_handlers={ HTTPException: http_exception_handler, RequestValidationError: request_validation_exception_handler, + Exception: server_error_exception_handler, } ) @@ -32,6 +38,11 @@ def route_with_request_validation_exception(param: int): pass # pragma: no cover +@app.get("/server-error") +def route_with_server_error(): + raise RuntimeError("Oops!") + + def test_override_http_exception(): response = client.get("/http-exception") assert response.status_code == 200 @@ -42,3 +53,15 @@ def test_override_request_validation_exception(): response = client.get("/request-validation/invalid") assert response.status_code == 200 assert response.json() == {"exception": "request-validation"} + + +def test_override_server_error_exception_raises(): + with pytest.raises(RuntimeError): + client.get("/server-error") + + +def test_override_server_error_exception_response(): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/server-error") + assert response.status_code == 500 + assert response.json() == {"exception": "server-error"}