From 5bdbcc3c895909273ee4e44355951e22d23c9685 Mon Sep 17 00:00:00 2001 From: Lancetnik Date: Tue, 6 Jun 2023 21:13:43 +0300 Subject: [PATCH] nest routers` lifespans --- fastapi/routing.py | 25 ++++++++++++++-- tests/test_router_events.py | 57 +++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 06c71bffa..29b29977b 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -3,9 +3,10 @@ import dataclasses import email.message import inspect import json -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum, IntEnum from typing import ( + AsyncIterator, Any, Callable, Coroutine, @@ -17,6 +18,7 @@ from typing import ( Tuple, Type, Union, + Mapping, ) from fastapi import params @@ -57,7 +59,7 @@ from starlette.routing import ( websocket_session, ) from starlette.status import WS_1008_POLICY_VIOLATION -from starlette.types import ASGIApp, Lifespan, Scope +from starlette.types import ASGIApp, Lifespan, Scope, AppType from starlette.websockets import WebSocket @@ -107,6 +109,21 @@ def _prepare_response_content( return res +def _merge_lifespan_context( + original_context: Lifespan[Any], nested_context: Lifespan[Any] +) -> Lifespan[Any]: + @asynccontextmanager + async def merged_lifespan(app: AppType) -> AsyncIterator[Mapping[str, Any]]: + async with original_context(app) as maybe_self_context: + async with nested_context(app) as maybe_nested_context: + context = maybe_self_context or {} + if maybe_nested_context: + context.update(maybe_nested_context) + yield context + + return merged_lifespan + + async def serialize_response( *, field: Optional[ModelField] = None, @@ -830,6 +847,10 @@ class APIRouter(routing.Router): self.add_event_handler("startup", handler) for handler in router.on_shutdown: self.add_event_handler("shutdown", handler) + self.lifespan_context = _merge_lifespan_context( + self.lifespan_context, + router.lifespan_context, + ) def get( self, diff --git a/tests/test_router_events.py b/tests/test_router_events.py index ba6b76382..440b4b469 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -106,3 +106,60 @@ def test_app_lifespan_state(state: State) -> None: assert response.json() == {"message": "Hello World"} assert state.app_startup is True assert state.app_shutdown is True + + +def test_router_nested_lifespan_state(state: State) -> None: + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + state.app_startup = True + yield + state.app_shutdown = True + + @asynccontextmanager + async def router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + state.router_startup = True + yield + state.router_shutdown = True + + @asynccontextmanager + async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + state.sub_router_startup = True + yield + state.sub_router_shutdown = True + + sub_router = APIRouter(lifespan=subrouter_lifespan) + + router = APIRouter(lifespan=router_lifespan) + router.include_router(sub_router) + + app = FastAPI(lifespan=lifespan) + app.include_router(router) + + @app.get("/") + def main() -> Dict[str, str]: + return {"message": "Hello World"} + + assert state.app_startup is False + assert state.router_startup is False + assert state.sub_router_startup is False + assert state.app_shutdown is False + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + + with TestClient(app) as client: + assert state.app_startup is True + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.app_shutdown is False + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + + assert state.app_startup is True + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.app_shutdown is True + assert state.router_shutdown is True + assert state.sub_router_shutdown is True