Browse Source

Merge 0d27cd0eb8 into 460f8d2cc8

pull/15424/merge
Alfred Santacatalina Gea 12 hours ago
committed by GitHub
parent
commit
fde7c8668c
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 31
      fastapi/routing.py
  2. 131
      tests/test_router_events.py

31
fastapi/routing.py

@ -213,12 +213,16 @@ def _merge_lifespan_context(
async def merged_lifespan(
app: AppType,
) -> AsyncIterator[Mapping[str, Any] | None]:
async with original_context(app) as maybe_original_state:
async with nested_context(app) as maybe_nested_state:
if maybe_nested_state is None and maybe_original_state is None:
yield None # old ASGI compatibility
else:
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
async with AsyncExitStack() as stack:
maybe_original_state = await stack.enter_async_context(
original_context(app)
)
maybe_nested_state = await stack.enter_async_context(nested_context(app))
if maybe_nested_state is None and maybe_original_state is None:
yield None # old ASGI compatibility
else:
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
return merged_lifespan # type: ignore[return-value] # ty: ignore[invalid-return-type]
@ -1823,10 +1827,17 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)
if type(router.lifespan_context) is _DefaultLifespan:
return
if type(self.lifespan_context) is _DefaultLifespan:
self.lifespan_context = router.lifespan_context
else:
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)
def get(
self,

131
tests/test_router_events.py

@ -171,6 +171,137 @@ def test_router_nested_lifespan_state(state: State) -> None:
assert state.sub_router_shutdown is True
def test_router_nested_lifespan_state_discard_default_lifespan_app(
state: State,
) -> None:
@asynccontextmanager
async def router_lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]:
state.router_startup = True
yield {"router": True}
state.router_shutdown = True
@asynccontextmanager
async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]:
state.sub_router_startup = True
yield {"sub_router": True}
state.sub_router_shutdown = True
sub_router = APIRouter(lifespan=subrouter_lifespan)
sub_router_lifespan_ctx = sub_router.lifespan_context
router = APIRouter(lifespan=router_lifespan)
router_lifespan_ctx = router.lifespan_context
router.include_router(sub_router)
assert router.lifespan_context is not router_lifespan_ctx, (
"Including a sub-router with a lifespan should change the router's lifespan context"
)
assert router.lifespan_context is not sub_router_lifespan_ctx, (
"New router lifespan context should not be the same as the sub-router's lifespan context, since the router should merge the lifespan contexts of all included sub-routers"
)
app = FastAPI()
app_lifespan_ctx = app.router.lifespan_context
app.include_router(router)
assert app.router.lifespan_context is not app_lifespan_ctx, (
"Including a router with a lifespan should change the app's lifespan context"
)
assert app.router.lifespan_context is router.lifespan_context, (
"New app lifespan context should be the same as the router's lifespan context, since the app has a default lifespan"
)
assert app.router.lifespan_context is not sub_router_lifespan_ctx, (
"New app lifespan context should not be the same as the sub-router's lifespan context, since the app should merge the lifespan contexts of all included routers"
)
@app.get("/")
def main(request: Request) -> dict[str, str]:
assert request.state.router
assert request.state.sub_router
return {"message": "Hello World"}
assert state.router_startup is False
assert state.sub_router_startup is False
assert state.router_shutdown is False
assert state.sub_router_shutdown is False
with TestClient(app) as client:
assert state.router_startup is True
assert state.sub_router_startup is True
assert state.router_shutdown is False
assert state.sub_router_shutdown is False
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
assert state.router_startup is True
assert state.sub_router_startup is True
assert state.router_shutdown is True
assert state.sub_router_shutdown is True
def test_router_nested_lifespan_state_discard_default_lifespan_child(
state: State,
) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]:
state.app_startup = True
yield {"app": True}
state.app_shutdown = True
router = APIRouter()
app = FastAPI(lifespan=lifespan)
app_lifespan_ctx = app.router.lifespan_context
app.include_router(router)
assert app.router.lifespan_context is app_lifespan_ctx, (
"Including a router without a lifespan should not change the app's lifespan context"
)
@app.get("/")
def main(request: Request) -> dict[str, str]:
assert request.state.app
return {"message": "Hello World"}
assert state.app_startup is False
assert state.app_shutdown is False
with TestClient(app) as client:
assert state.app_startup is True
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
assert state.app_shutdown is True
def test_router_nested_lifespan_state_no_lifespans(
state: State,
) -> None:
"""Test that if no lifespans are provided, the app still works and the state is empty."""
router = APIRouter()
app = FastAPI()
app_lifespan_ctx = app.router.lifespan_context
app.include_router(router)
assert app.router.lifespan_context is app_lifespan_ctx, (
"Including a router without a lifespan should not change the app's lifespan context"
)
@app.get("/")
def main(request: Request) -> dict[str, str]:
return {"message": "Hello World"}
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
def test_router_nested_lifespan_state_overriding_by_parent() -> None:
@asynccontextmanager
async def lifespan(

Loading…
Cancel
Save