Browse Source

Discard default lifespans before merging

pull/15403/head
Alfred 2 months ago
parent
commit
0f1b333e99
  1. 31
      fastapi/routing.py
  2. 101
      tests/test_router_events.py

31
fastapi/routing.py

@ -213,12 +213,16 @@ def _merge_lifespan_context(
async def merged_lifespan(
app: AppType,
) -> AsyncIterator[Mapping[str, Any] | None]:
async with original_context(app) as maybe_original_state:
async with nested_context(app) as maybe_nested_state:
if maybe_nested_state is None and maybe_original_state is None:
yield None # old ASGI compatibility
else:
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
async with AsyncExitStack() as stack:
maybe_original_state = await stack.enter_async_context(
original_context(app)
)
maybe_nested_state = await stack.enter_async_context(nested_context(app))
if maybe_nested_state is None and maybe_original_state is None:
yield None # old ASGI compatibility
else:
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
return merged_lifespan # type: ignore[return-value]
@ -1823,10 +1827,17 @@ class APIRouter(routing.Router):
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)
if type(router.lifespan_context) is _DefaultLifespan:
return
if type(self.lifespan_context) is _DefaultLifespan:
self.lifespan_context = router.lifespan_context
else:
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)
def get(
self,

101
tests/test_router_events.py

@ -171,6 +171,107 @@ def test_router_nested_lifespan_state(state: State) -> None:
assert state.sub_router_shutdown is True
def test_router_nested_lifespan_state_discard_default_lifespan_app(
state: State,
) -> None:
@asynccontextmanager
async def router_lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]:
state.router_startup = True
yield {"router": True}
state.router_shutdown = True
@asynccontextmanager
async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]:
state.sub_router_startup = True
yield {"sub_router": True}
state.sub_router_shutdown = True
sub_router = APIRouter(lifespan=subrouter_lifespan)
router = APIRouter(lifespan=router_lifespan)
router.include_router(sub_router)
app = FastAPI()
app.include_router(router)
@app.get("/")
def main(request: Request) -> dict[str, str]:
assert request.state.router
assert request.state.sub_router
return {"message": "Hello World"}
assert state.router_startup is False
assert state.sub_router_startup is False
assert state.router_shutdown is False
assert state.sub_router_shutdown is False
with TestClient(app) as client:
assert state.router_startup is True
assert state.sub_router_startup is True
assert state.router_shutdown is False
assert state.sub_router_shutdown is False
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
assert state.router_startup is True
assert state.sub_router_startup is True
assert state.router_shutdown is True
assert state.sub_router_shutdown is True
def test_router_nested_lifespan_state_discard_default_lifespan_child(
state: State,
) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[dict[str, bool], None]:
state.app_startup = True
yield {"app": True}
state.app_shutdown = True
router = APIRouter()
app = FastAPI(lifespan=lifespan)
app.include_router(router)
@app.get("/")
def main(request: Request) -> dict[str, str]:
assert request.state.app
return {"message": "Hello World"}
assert state.app_startup is False
assert state.app_shutdown is False
with TestClient(app) as client:
assert state.app_startup is True
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
assert state.app_shutdown is True
def test_router_nested_lifespan_state_no_lifespans(
state: State,
) -> None:
"""Test that if no lifespans are provided, the app still works and the state is empty."""
router = APIRouter()
app = FastAPI()
app.include_router(router)
@app.get("/")
def main(request: Request) -> dict[str, str]:
return {"message": "Hello World"}
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
def test_router_nested_lifespan_state_overriding_by_parent() -> None:
@asynccontextmanager
async def lifespan(

Loading…
Cancel
Save