committed by
GitHub
7 changed files with 272 additions and 16 deletions
@ -0,0 +1,28 @@ |
|||
from typing import Optional |
|||
|
|||
from fastapi.concurrency import AsyncExitStack |
|||
from starlette.types import ASGIApp, Receive, Scope, Send |
|||
|
|||
|
|||
class AsyncExitStackMiddleware: |
|||
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None: |
|||
self.app = app |
|||
self.context_name = context_name |
|||
|
|||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
|||
if AsyncExitStack: |
|||
dependency_exception: Optional[Exception] = None |
|||
async with AsyncExitStack() as stack: |
|||
scope[self.context_name] = stack |
|||
try: |
|||
await self.app(scope, receive, send) |
|||
except Exception as e: |
|||
dependency_exception = e |
|||
raise e |
|||
if dependency_exception: |
|||
# This exception was possibly handled by the dependency but it should |
|||
# still bubble up so that the ServerErrorMiddleware can return a 500 |
|||
# or the ExceptionMiddleware can catch and handle any other exceptions |
|||
raise dependency_exception |
|||
else: |
|||
await self.app(scope, receive, send) # pragma: no cover |
@ -0,0 +1,51 @@ |
|||
from contextvars import ContextVar |
|||
from typing import Any, Awaitable, Callable, Dict, Optional |
|||
|
|||
from fastapi import Depends, FastAPI, Request, Response |
|||
from fastapi.testclient import TestClient |
|||
|
|||
legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar( |
|||
"legacy_request_state_context_var", default=None |
|||
) |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def set_up_request_state_dependency(): |
|||
request_state = {"user": "deadpond"} |
|||
contextvar_token = legacy_request_state_context_var.set(request_state) |
|||
yield request_state |
|||
legacy_request_state_context_var.reset(contextvar_token) |
|||
|
|||
|
|||
@app.middleware("http") |
|||
async def custom_middleware( |
|||
request: Request, call_next: Callable[[Request], Awaitable[Response]] |
|||
): |
|||
response = await call_next(request) |
|||
response.headers["custom"] = "foo" |
|||
return response |
|||
|
|||
|
|||
@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)]) |
|||
def get_user(): |
|||
request_state = legacy_request_state_context_var.get() |
|||
assert request_state |
|||
return request_state["user"] |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_dependency_contextvars(): |
|||
""" |
|||
Check that custom middlewares don't affect the contextvar context for dependencies. |
|||
|
|||
The code before yield and the code after yield should be run in the same contextvar |
|||
context, so that request_state_context_var.reset(contextvar_token). |
|||
|
|||
If they are run in a different context, that raises an error. |
|||
""" |
|||
response = client.get("/user") |
|||
assert response.json() == "deadpond" |
|||
assert response.headers["custom"] == "foo" |
@ -0,0 +1,71 @@ |
|||
import pytest |
|||
from fastapi import Body, Depends, FastAPI, HTTPException |
|||
from fastapi.testclient import TestClient |
|||
|
|||
initial_fake_database = {"rick": "Rick Sanchez"} |
|||
|
|||
fake_database = initial_fake_database.copy() |
|||
|
|||
initial_state = {"except": False, "finally": False} |
|||
|
|||
state = initial_state.copy() |
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
async def get_database(): |
|||
temp_database = fake_database.copy() |
|||
try: |
|||
yield temp_database |
|||
fake_database.update(temp_database) |
|||
except HTTPException: |
|||
state["except"] = True |
|||
finally: |
|||
state["finally"] = True |
|||
|
|||
|
|||
@app.put("/invalid-user/{user_id}") |
|||
def put_invalid_user( |
|||
user_id: str, name: str = Body(...), db: dict = Depends(get_database) |
|||
): |
|||
db[user_id] = name |
|||
raise HTTPException(status_code=400, detail="Invalid user") |
|||
|
|||
|
|||
@app.put("/user/{user_id}") |
|||
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)): |
|||
db[user_id] = name |
|||
return {"message": "OK"} |
|||
|
|||
|
|||
@pytest.fixture(autouse=True) |
|||
def reset_state_and_db(): |
|||
global fake_database |
|||
global state |
|||
fake_database = initial_fake_database.copy() |
|||
state = initial_state.copy() |
|||
|
|||
|
|||
client = TestClient(app) |
|||
|
|||
|
|||
def test_dependency_gets_exception(): |
|||
assert state["except"] is False |
|||
assert state["finally"] is False |
|||
response = client.put("/invalid-user/rick", json="Morty") |
|||
assert response.status_code == 400, response.text |
|||
assert response.json() == {"detail": "Invalid user"} |
|||
assert state["except"] is True |
|||
assert state["finally"] is True |
|||
assert fake_database["rick"] == "Rick Sanchez" |
|||
|
|||
|
|||
def test_dependency_no_exception(): |
|||
assert state["except"] is False |
|||
assert state["finally"] is False |
|||
response = client.put("/user/rick", json="Morty") |
|||
assert response.status_code == 200, response.text |
|||
assert response.json() == {"message": "OK"} |
|||
assert state["except"] is False |
|||
assert state["finally"] is True |
|||
assert fake_database["rick"] == "Morty" |
Loading…
Reference in new issue