From 0f1b333e99bc4591eb8ed1f8cae88845dc5e5033 Mon Sep 17 00:00:00 2001 From: Alfred Date: Wed, 22 Apr 2026 09:49:01 +0200 Subject: [PATCH] Discard default lifespans before merging --- fastapi/routing.py | 31 +++++++---- tests/test_router_events.py | 101 ++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 10 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 36acb6b89d..a276f038e0 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -213,12 +213,16 @@ def _merge_lifespan_context( async def merged_lifespan( app: AppType, ) -> AsyncIterator[Mapping[str, Any] | None]: - async with original_context(app) as maybe_original_state: - async with nested_context(app) as maybe_nested_state: - if maybe_nested_state is None and maybe_original_state is None: - yield None # old ASGI compatibility - else: - yield {**(maybe_nested_state or {}), **(maybe_original_state or {})} + async with AsyncExitStack() as stack: + maybe_original_state = await stack.enter_async_context( + original_context(app) + ) + maybe_nested_state = await stack.enter_async_context(nested_context(app)) + + if maybe_nested_state is None and maybe_original_state is None: + yield None # old ASGI compatibility + else: + yield {**(maybe_nested_state or {}), **(maybe_original_state or {})} return merged_lifespan # type: ignore[return-value] @@ -1823,10 +1827,17 @@ class APIRouter(routing.Router): self.add_event_handler("startup", handler) for handler in router.on_shutdown: self.add_event_handler("shutdown", handler) - self.lifespan_context = _merge_lifespan_context( - self.lifespan_context, - router.lifespan_context, - ) + + if type(router.lifespan_context) is _DefaultLifespan: + return + + if type(self.lifespan_context) is _DefaultLifespan: + self.lifespan_context = router.lifespan_context + else: + self.lifespan_context = _merge_lifespan_context( + self.lifespan_context, + router.lifespan_context, + ) def get( self, diff --git a/tests/test_router_events.py b/tests/test_router_events.py index 7869a7afcd..b4c9bb4c41 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -171,6 +171,107 @@ def test_router_nested_lifespan_state(state: State) -> None: assert state.sub_router_shutdown is True +def test_router_nested_lifespan_state_discard_default_lifespan_app( + state: State, +) -> None: + + @asynccontextmanager + async def router_lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]: + state.router_startup = True + yield {"router": True} + state.router_shutdown = True + + @asynccontextmanager + async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]: + state.sub_router_startup = True + yield {"sub_router": True} + state.sub_router_shutdown = True + + sub_router = APIRouter(lifespan=subrouter_lifespan) + + router = APIRouter(lifespan=router_lifespan) + router.include_router(sub_router) + + app = FastAPI() + app.include_router(router) + + @app.get("/") + def main(request: Request) -> dict[str, str]: + assert request.state.router + assert request.state.sub_router + return {"message": "Hello World"} + + assert state.router_startup is False + assert state.sub_router_startup is False + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + + with TestClient(app) as client: + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.router_shutdown is True + assert state.sub_router_shutdown is True + + +def test_router_nested_lifespan_state_discard_default_lifespan_child( + state: State, +) -> None: + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]: + state.app_startup = True + yield {"app": True} + state.app_shutdown = True + + router = APIRouter() + + app = FastAPI(lifespan=lifespan) + app.include_router(router) + + @app.get("/") + def main(request: Request) -> dict[str, str]: + assert request.state.app + return {"message": "Hello World"} + + assert state.app_startup is False + assert state.app_shutdown is False + + with TestClient(app) as client: + assert state.app_startup is True + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + + assert state.app_shutdown is True + + +def test_router_nested_lifespan_state_no_lifespans( + state: State, +) -> None: + """Test that if no lifespans are provided, the app still works and the state is empty.""" + router = APIRouter() + + app = FastAPI() + app.include_router(router) + + @app.get("/") + def main(request: Request) -> dict[str, str]: + return {"message": "Hello World"} + + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + + def test_router_nested_lifespan_state_overriding_by_parent() -> None: @asynccontextmanager async def lifespan(