committed by
GitHub
7 changed files with 272 additions and 16 deletions
@ -0,0 +1,28 @@ |
|||||
|
from typing import Optional |
||||
|
|
||||
|
from fastapi.concurrency import AsyncExitStack |
||||
|
from starlette.types import ASGIApp, Receive, Scope, Send |
||||
|
|
||||
|
|
||||
|
class AsyncExitStackMiddleware: |
||||
|
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None: |
||||
|
self.app = app |
||||
|
self.context_name = context_name |
||||
|
|
||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
||||
|
if AsyncExitStack: |
||||
|
dependency_exception: Optional[Exception] = None |
||||
|
async with AsyncExitStack() as stack: |
||||
|
scope[self.context_name] = stack |
||||
|
try: |
||||
|
await self.app(scope, receive, send) |
||||
|
except Exception as e: |
||||
|
dependency_exception = e |
||||
|
raise e |
||||
|
if dependency_exception: |
||||
|
# This exception was possibly handled by the dependency but it should |
||||
|
# still bubble up so that the ServerErrorMiddleware can return a 500 |
||||
|
# or the ExceptionMiddleware can catch and handle any other exceptions |
||||
|
raise dependency_exception |
||||
|
else: |
||||
|
await self.app(scope, receive, send) # pragma: no cover |
@ -0,0 +1,51 @@ |
|||||
|
from contextvars import ContextVar |
||||
|
from typing import Any, Awaitable, Callable, Dict, Optional |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, Request, Response |
||||
|
from fastapi.testclient import TestClient |
||||
|
|
||||
|
legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar( |
||||
|
"legacy_request_state_context_var", default=None |
||||
|
) |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def set_up_request_state_dependency(): |
||||
|
request_state = {"user": "deadpond"} |
||||
|
contextvar_token = legacy_request_state_context_var.set(request_state) |
||||
|
yield request_state |
||||
|
legacy_request_state_context_var.reset(contextvar_token) |
||||
|
|
||||
|
|
||||
|
@app.middleware("http") |
||||
|
async def custom_middleware( |
||||
|
request: Request, call_next: Callable[[Request], Awaitable[Response]] |
||||
|
): |
||||
|
response = await call_next(request) |
||||
|
response.headers["custom"] = "foo" |
||||
|
return response |
||||
|
|
||||
|
|
||||
|
@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)]) |
||||
|
def get_user(): |
||||
|
request_state = legacy_request_state_context_var.get() |
||||
|
assert request_state |
||||
|
return request_state["user"] |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_dependency_contextvars(): |
||||
|
""" |
||||
|
Check that custom middlewares don't affect the contextvar context for dependencies. |
||||
|
|
||||
|
The code before yield and the code after yield should be run in the same contextvar |
||||
|
context, so that request_state_context_var.reset(contextvar_token). |
||||
|
|
||||
|
If they are run in a different context, that raises an error. |
||||
|
""" |
||||
|
response = client.get("/user") |
||||
|
assert response.json() == "deadpond" |
||||
|
assert response.headers["custom"] == "foo" |
@ -0,0 +1,71 @@ |
|||||
|
import pytest |
||||
|
from fastapi import Body, Depends, FastAPI, HTTPException |
||||
|
from fastapi.testclient import TestClient |
||||
|
|
||||
|
initial_fake_database = {"rick": "Rick Sanchez"} |
||||
|
|
||||
|
fake_database = initial_fake_database.copy() |
||||
|
|
||||
|
initial_state = {"except": False, "finally": False} |
||||
|
|
||||
|
state = initial_state.copy() |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
async def get_database(): |
||||
|
temp_database = fake_database.copy() |
||||
|
try: |
||||
|
yield temp_database |
||||
|
fake_database.update(temp_database) |
||||
|
except HTTPException: |
||||
|
state["except"] = True |
||||
|
finally: |
||||
|
state["finally"] = True |
||||
|
|
||||
|
|
||||
|
@app.put("/invalid-user/{user_id}") |
||||
|
def put_invalid_user( |
||||
|
user_id: str, name: str = Body(...), db: dict = Depends(get_database) |
||||
|
): |
||||
|
db[user_id] = name |
||||
|
raise HTTPException(status_code=400, detail="Invalid user") |
||||
|
|
||||
|
|
||||
|
@app.put("/user/{user_id}") |
||||
|
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)): |
||||
|
db[user_id] = name |
||||
|
return {"message": "OK"} |
||||
|
|
||||
|
|
||||
|
@pytest.fixture(autouse=True) |
||||
|
def reset_state_and_db(): |
||||
|
global fake_database |
||||
|
global state |
||||
|
fake_database = initial_fake_database.copy() |
||||
|
state = initial_state.copy() |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_dependency_gets_exception(): |
||||
|
assert state["except"] is False |
||||
|
assert state["finally"] is False |
||||
|
response = client.put("/invalid-user/rick", json="Morty") |
||||
|
assert response.status_code == 400, response.text |
||||
|
assert response.json() == {"detail": "Invalid user"} |
||||
|
assert state["except"] is True |
||||
|
assert state["finally"] is True |
||||
|
assert fake_database["rick"] == "Rick Sanchez" |
||||
|
|
||||
|
|
||||
|
def test_dependency_no_exception(): |
||||
|
assert state["except"] is False |
||||
|
assert state["finally"] is False |
||||
|
response = client.put("/user/rick", json="Morty") |
||||
|
assert response.status_code == 200, response.text |
||||
|
assert response.json() == {"message": "OK"} |
||||
|
assert state["except"] is False |
||||
|
assert state["finally"] is True |
||||
|
assert fake_database["rick"] == "Morty" |
Loading…
Reference in new issue