Browse Source

Update internal `AsyncExitStack` to fix context for dependencies with `yield` (#4575)

pull/4596/head
Sebastián Ramírez 3 years ago
committed by GitHub
parent
commit
9d56a3cb59
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
  2. 61
      fastapi/applications.py
  3. 28
      fastapi/middleware/asyncexitstack.py
  4. 44
      tests/test_dependency_contextmanager.py
  5. 51
      tests/test_dependency_contextvars.py
  6. 71
      tests/test_dependency_normal_exceptions.py
  7. 23
      tests/test_exception_handlers.py

10
docs/en/docs/tutorial/dependencies/dependencies-with-yield.md

@ -99,7 +99,7 @@ You saw that you can use dependencies with `yield` and have `try` blocks that ca
It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**. It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.
The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`). The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore. So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.
@ -138,9 +138,11 @@ participant tasks as Background tasks
end end
dep ->> operation: Run dependency, e.g. DB session dep ->> operation: Run dependency, e.g. DB session
opt raise opt raise
operation -->> handler: Raise HTTPException operation -->> dep: Raise HTTPException
dep -->> handler: Auto forward exception
handler -->> client: HTTP error response handler -->> client: HTTP error response
operation -->> dep: Raise other exception operation -->> dep: Raise other exception
dep -->> handler: Auto forward exception
end end
operation ->> client: Return response to client operation ->> client: Return response to client
Note over client,operation: Response is already sent, can't change it anymore Note over client,operation: Response is already sent, can't change it anymore
@ -162,9 +164,9 @@ participant tasks as Background tasks
After one of those responses is sent, no other response can be sent. After one of those responses is sent, no other response can be sent.
!!! tip !!! tip
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code. This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}.
But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency. If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server.
## Context Managers ## Context Managers

61
fastapi/applications.py

@ -2,7 +2,6 @@ from enum import Enum
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
from fastapi import routing from fastapi import routing
from fastapi.concurrency import AsyncExitStack
from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.encoders import DictIntStrAny, SetIntStr from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import ( from fastapi.exception_handlers import (
@ -11,6 +10,7 @@ from fastapi.exception_handlers import (
) )
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import ( from fastapi.openapi.docs import (
get_redoc_html, get_redoc_html,
get_swagger_ui_html, get_swagger_ui_html,
@ -21,8 +21,9 @@ from fastapi.params import Depends
from fastapi.types import DecoratedCallable from fastapi.types import DecoratedCallable
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.datastructures import State from starlette.datastructures import State
from starlette.exceptions import HTTPException from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute from starlette.routing import BaseRoute
@ -134,6 +135,55 @@ class FastAPI(Starlette):
self.openapi_schema: Optional[Dict[str, Any]] = None self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup() self.setup()
def build_middleware_stack(self) -> ASGIApp:
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
# inside of ExceptionMiddleware, inside of custom user middlewares
debug = self.debug
error_handler = None
exception_handlers = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
error_handler = value
else:
exception_handlers[key] = value
middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
),
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
# contextvars.
# This needs to happen after user middlewares because those create a
# new contextvars context copy by using a new AnyIO task group.
# The initial part of dependencies with yield is executed in the
# FastAPI code, inside all the middlewares, but the teardown part
# (after yield) is executed in the AsyncExitStack in this middleware,
# if the AsyncExitStack lived outside of the custom middlewares and
# contextvars were set in a dependency with yield in that internal
# contextvars context, the values would not be available in the
# outside context of the AsyncExitStack.
# By putting the middleware and the AsyncExitStack here, inside all
# user middlewares, the code before and after yield in dependencies
# with yield is executed in the same contextvars context, so all values
# set in contextvars before yield is still available after yield as
# would be expected.
# Additionally, by having this AsyncExitStack here, after the
# ExceptionMiddleware, now dependencies can catch handled exceptions,
# e.g. HTTPException, to customize the teardown code (e.g. DB session
# rollback).
Middleware(AsyncExitStackMiddleware),
]
)
app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
return app
def openapi(self) -> Dict[str, Any]: def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema: if not self.openapi_schema:
self.openapi_schema = get_openapi( self.openapi_schema = get_openapi(
@ -206,12 +256,7 @@ class FastAPI(Starlette):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.root_path: if self.root_path:
scope["root_path"] = self.root_path scope["root_path"] = self.root_path
if AsyncExitStack: await super().__call__(scope, receive, send)
async with AsyncExitStack() as stack:
scope["fastapi_astack"] = stack
await super().__call__(scope, receive, send)
else:
await super().__call__(scope, receive, send) # pragma: no cover
def add_api_route( def add_api_route(
self, self,

28
fastapi/middleware/asyncexitstack.py

@ -0,0 +1,28 @@
from typing import Optional
from fastapi.concurrency import AsyncExitStack
from starlette.types import ASGIApp, Receive, Scope, Send
class AsyncExitStackMiddleware:
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
self.app = app
self.context_name = context_name
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if AsyncExitStack:
dependency_exception: Optional[Exception] = None
async with AsyncExitStack() as stack:
scope[self.context_name] = stack
try:
await self.app(scope, receive, send)
except Exception as e:
dependency_exception = e
raise e
if dependency_exception:
# This exception was possibly handled by the dependency but it should
# still bubble up so that the ServerErrorMiddleware can return a 500
# or the ExceptionMiddleware can catch and handle any other exceptions
raise dependency_exception
else:
await self.app(scope, receive, send) # pragma: no cover

44
tests/test_dependency_contextmanager.py

@ -235,7 +235,16 @@ def test_sync_raise_other():
assert "/sync_raise" not in errors assert "/sync_raise" not in errors
def test_async_raise(): def test_async_raise_raises():
with pytest.raises(AsyncDependencyError):
client.get("/async_raise")
assert state["/async_raise"] == "asyncgen raise finalized"
assert "/async_raise" in errors
errors.clear()
def test_async_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/async_raise") response = client.get("/async_raise")
assert response.status_code == 500, response.text assert response.status_code == 500, response.text
assert state["/async_raise"] == "asyncgen raise finalized" assert state["/async_raise"] == "asyncgen raise finalized"
@ -270,7 +279,16 @@ def test_background_tasks():
assert state["bg"] == "bg set - b: started b - a: started a" assert state["bg"] == "bg set - b: started b - a: started a"
def test_sync_raise(): def test_sync_raise_raises():
with pytest.raises(SyncDependencyError):
client.get("/sync_raise")
assert state["/sync_raise"] == "generator raise finalized"
assert "/sync_raise" in errors
errors.clear()
def test_sync_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_raise") response = client.get("/sync_raise")
assert response.status_code == 500, response.text assert response.status_code == 500, response.text
assert state["/sync_raise"] == "generator raise finalized" assert state["/sync_raise"] == "generator raise finalized"
@ -306,7 +324,16 @@ def test_sync_sync_raise_other():
assert "/sync_raise" not in errors assert "/sync_raise" not in errors
def test_sync_async_raise(): def test_sync_async_raise_raises():
with pytest.raises(AsyncDependencyError):
client.get("/sync_async_raise")
assert state["/async_raise"] == "asyncgen raise finalized"
assert "/async_raise" in errors
errors.clear()
def test_sync_async_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_async_raise") response = client.get("/sync_async_raise")
assert response.status_code == 500, response.text assert response.status_code == 500, response.text
assert state["/async_raise"] == "asyncgen raise finalized" assert state["/async_raise"] == "asyncgen raise finalized"
@ -314,7 +341,16 @@ def test_sync_async_raise():
errors.clear() errors.clear()
def test_sync_sync_raise(): def test_sync_sync_raise_raises():
with pytest.raises(SyncDependencyError):
client.get("/sync_sync_raise")
assert state["/sync_raise"] == "generator raise finalized"
assert "/sync_raise" in errors
errors.clear()
def test_sync_sync_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_sync_raise") response = client.get("/sync_sync_raise")
assert response.status_code == 500, response.text assert response.status_code == 500, response.text
assert state["/sync_raise"] == "generator raise finalized" assert state["/sync_raise"] == "generator raise finalized"

51
tests/test_dependency_contextvars.py

@ -0,0 +1,51 @@
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, Dict, Optional
from fastapi import Depends, FastAPI, Request, Response
from fastapi.testclient import TestClient
legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
"legacy_request_state_context_var", default=None
)
app = FastAPI()
async def set_up_request_state_dependency():
request_state = {"user": "deadpond"}
contextvar_token = legacy_request_state_context_var.set(request_state)
yield request_state
legacy_request_state_context_var.reset(contextvar_token)
@app.middleware("http")
async def custom_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["custom"] = "foo"
return response
@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
def get_user():
request_state = legacy_request_state_context_var.get()
assert request_state
return request_state["user"]
client = TestClient(app)
def test_dependency_contextvars():
"""
Check that custom middlewares don't affect the contextvar context for dependencies.
The code before yield and the code after yield should be run in the same contextvar
context, so that request_state_context_var.reset(contextvar_token).
If they are run in a different context, that raises an error.
"""
response = client.get("/user")
assert response.json() == "deadpond"
assert response.headers["custom"] == "foo"

71
tests/test_dependency_normal_exceptions.py

@ -0,0 +1,71 @@
import pytest
from fastapi import Body, Depends, FastAPI, HTTPException
from fastapi.testclient import TestClient
initial_fake_database = {"rick": "Rick Sanchez"}
fake_database = initial_fake_database.copy()
initial_state = {"except": False, "finally": False}
state = initial_state.copy()
app = FastAPI()
async def get_database():
temp_database = fake_database.copy()
try:
yield temp_database
fake_database.update(temp_database)
except HTTPException:
state["except"] = True
finally:
state["finally"] = True
@app.put("/invalid-user/{user_id}")
def put_invalid_user(
user_id: str, name: str = Body(...), db: dict = Depends(get_database)
):
db[user_id] = name
raise HTTPException(status_code=400, detail="Invalid user")
@app.put("/user/{user_id}")
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)):
db[user_id] = name
return {"message": "OK"}
@pytest.fixture(autouse=True)
def reset_state_and_db():
global fake_database
global state
fake_database = initial_fake_database.copy()
state = initial_state.copy()
client = TestClient(app)
def test_dependency_gets_exception():
assert state["except"] is False
assert state["finally"] is False
response = client.put("/invalid-user/rick", json="Morty")
assert response.status_code == 400, response.text
assert response.json() == {"detail": "Invalid user"}
assert state["except"] is True
assert state["finally"] is True
assert fake_database["rick"] == "Rick Sanchez"
def test_dependency_no_exception():
assert state["except"] is False
assert state["finally"] is False
response = client.put("/user/rick", json="Morty")
assert response.status_code == 200, response.text
assert response.json() == {"message": "OK"}
assert state["except"] is False
assert state["finally"] is True
assert fake_database["rick"] == "Morty"

23
tests/test_exception_handlers.py

@ -1,3 +1,4 @@
import pytest
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -12,10 +13,15 @@ def request_validation_exception_handler(request, exception):
return JSONResponse({"exception": "request-validation"}) return JSONResponse({"exception": "request-validation"})
def server_error_exception_handler(request, exception):
return JSONResponse(status_code=500, content={"exception": "server-error"})
app = FastAPI( app = FastAPI(
exception_handlers={ exception_handlers={
HTTPException: http_exception_handler, HTTPException: http_exception_handler,
RequestValidationError: request_validation_exception_handler, RequestValidationError: request_validation_exception_handler,
Exception: server_error_exception_handler,
} }
) )
@ -32,6 +38,11 @@ def route_with_request_validation_exception(param: int):
pass # pragma: no cover pass # pragma: no cover
@app.get("/server-error")
def route_with_server_error():
raise RuntimeError("Oops!")
def test_override_http_exception(): def test_override_http_exception():
response = client.get("/http-exception") response = client.get("/http-exception")
assert response.status_code == 200 assert response.status_code == 200
@ -42,3 +53,15 @@ def test_override_request_validation_exception():
response = client.get("/request-validation/invalid") response = client.get("/request-validation/invalid")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"exception": "request-validation"} assert response.json() == {"exception": "request-validation"}
def test_override_server_error_exception_raises():
with pytest.raises(RuntimeError):
client.get("/server-error")
def test_override_server_error_exception_response():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/server-error")
assert response.status_code == 500
assert response.json() == {"exception": "server-error"}

Loading…
Cancel
Save