From 3a4ac2467594d0ccad92ecfb7f7f10ffa5d1d992 Mon Sep 17 00:00:00 2001 From: Pastukhov Nikita Date: Sat, 24 Aug 2024 22:09:52 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Ensure=20that=20`app.include=5Fr?= =?UTF-8?q?outer`=20merges=20nested=20lifespans=20(#9630)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marcelo Trylesinski Co-authored-by: Sebastián Ramírez --- fastapi/routing.py | 27 +++++++- tests/test_router_events.py | 135 +++++++++++++++++++++++++++++++++++- 2 files changed, 158 insertions(+), 4 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 2e7959f3d..49f1b6013 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -3,14 +3,16 @@ import dataclasses import email.message import inspect import json -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum, IntEnum from typing import ( Any, + AsyncIterator, Callable, Coroutine, Dict, List, + Mapping, Optional, Sequence, Set, @@ -67,7 +69,7 @@ from starlette.routing import ( websocket_session, ) from starlette.routing import Mount as Mount # noqa -from starlette.types import ASGIApp, Lifespan, Scope +from starlette.types import AppType, ASGIApp, Lifespan, Scope from starlette.websockets import WebSocket from typing_extensions import Annotated, Doc, deprecated @@ -119,6 +121,23 @@ def _prepare_response_content( return res +def _merge_lifespan_context( + original_context: Lifespan[Any], nested_context: Lifespan[Any] +) -> Lifespan[Any]: + @asynccontextmanager + async def merged_lifespan( + app: AppType, + ) -> AsyncIterator[Optional[Mapping[str, Any]]]: + async with original_context(app) as maybe_original_state: + async with nested_context(app) as maybe_nested_state: + if maybe_nested_state is None and maybe_original_state is None: + yield None # old ASGI compatibility + else: + yield {**(maybe_nested_state or {}), **(maybe_original_state or {})} + + return merged_lifespan # type: ignore[return-value] + + async def serialize_response( *, field: Optional[ModelField] = None, @@ -1308,6 +1327,10 @@ class APIRouter(routing.Router): self.add_event_handler("startup", handler) for handler in router.on_shutdown: self.add_event_handler("shutdown", handler) + self.lifespan_context = _merge_lifespan_context( + self.lifespan_context, + router.lifespan_context, + ) def get( self, diff --git a/tests/test_router_events.py b/tests/test_router_events.py index 1b9de18ae..dd7ff3314 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -1,8 +1,8 @@ from contextlib import asynccontextmanager -from typing import AsyncGenerator, Dict +from typing import AsyncGenerator, Dict, Union import pytest -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, FastAPI, Request from fastapi.testclient import TestClient from pydantic import BaseModel @@ -109,3 +109,134 @@ def test_app_lifespan_state(state: State) -> None: assert response.json() == {"message": "Hello World"} assert state.app_startup is True assert state.app_shutdown is True + + +def test_router_nested_lifespan_state(state: State) -> None: + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]: + state.app_startup = True + yield {"app": True} + state.app_shutdown = True + + @asynccontextmanager + async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]: + state.router_startup = True + yield {"router": True} + state.router_shutdown = True + + @asynccontextmanager + async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]: + state.sub_router_startup = True + yield {"sub_router": True} + state.sub_router_shutdown = True + + sub_router = APIRouter(lifespan=subrouter_lifespan) + + router = APIRouter(lifespan=router_lifespan) + router.include_router(sub_router) + + app = FastAPI(lifespan=lifespan) + app.include_router(router) + + @app.get("/") + def main(request: Request) -> Dict[str, str]: + assert request.state.app + assert request.state.router + assert request.state.sub_router + return {"message": "Hello World"} + + assert state.app_startup is False + assert state.router_startup is False + assert state.sub_router_startup is False + assert state.app_shutdown is False + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + + with TestClient(app) as client: + assert state.app_startup is True + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.app_shutdown is False + assert state.router_shutdown is False + assert state.sub_router_shutdown is False + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + + assert state.app_startup is True + assert state.router_startup is True + assert state.sub_router_startup is True + assert state.app_shutdown is True + assert state.router_shutdown is True + assert state.sub_router_shutdown is True + + +def test_router_nested_lifespan_state_overriding_by_parent() -> None: + @asynccontextmanager + async def lifespan( + app: FastAPI, + ) -> AsyncGenerator[Dict[str, Union[str, bool]], None]: + yield { + "app_specific": True, + "overridden": "app", + } + + @asynccontextmanager + async def router_lifespan( + app: FastAPI, + ) -> AsyncGenerator[Dict[str, Union[str, bool]], None]: + yield { + "router_specific": True, + "overridden": "router", # should override parent + } + + router = APIRouter(lifespan=router_lifespan) + app = FastAPI(lifespan=lifespan) + app.include_router(router) + + with TestClient(app) as client: + assert client.app_state == { + "app_specific": True, + "router_specific": True, + "overridden": "app", + } + + +def test_merged_no_return_lifespans_return_none() -> None: + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + yield + + @asynccontextmanager + async def router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + yield + + router = APIRouter(lifespan=router_lifespan) + app = FastAPI(lifespan=lifespan) + app.include_router(router) + + with TestClient(app) as client: + assert not client.app_state + + +def test_merged_mixed_state_lifespans() -> None: + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + yield + + @asynccontextmanager + async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]: + yield {"router": True} + + @asynccontextmanager + async def sub_router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + yield + + sub_router = APIRouter(lifespan=sub_router_lifespan) + router = APIRouter(lifespan=router_lifespan) + app = FastAPI(lifespan=lifespan) + router.include_router(sub_router) + app.include_router(router) + + with TestClient(app) as client: + assert client.app_state == {"router": True}