|
|
@ -1,10 +1,15 @@ |
|
|
from collections.abc import AsyncGenerator |
|
|
from collections.abc import AsyncGenerator |
|
|
from contextlib import asynccontextmanager |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
from typing import TypedDict |
|
|
|
|
|
|
|
|
import pytest |
|
|
import pytest |
|
|
from fastapi import APIRouter, FastAPI, Request |
|
|
from fastapi import APIRouter, FastAPI, Request |
|
|
from fastapi.testclient import TestClient |
|
|
from fastapi.testclient import TestClient |
|
|
from pydantic import BaseModel |
|
|
from pydantic import BaseModel |
|
|
|
|
|
from starlette import __version__ as STARLETTE_VERSION |
|
|
|
|
|
from typing_extensions import Self |
|
|
|
|
|
|
|
|
|
|
|
STARLETTE_MINOR_VERSION_TUPLE = tuple(int(x) for x in STARLETTE_VERSION.split(".")[:2]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class State(BaseModel): |
|
|
class State(BaseModel): |
|
|
@ -171,6 +176,39 @@ def test_router_nested_lifespan_state(state: State) -> None: |
|
|
assert state.sub_router_shutdown is True |
|
|
assert state.sub_router_shutdown is True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif( |
|
|
|
|
|
STARLETTE_MINOR_VERSION_TUPLE < (0, 52), |
|
|
|
|
|
reason="Starlette Request with generic type is not supported in Starlette < 0.52.0", |
|
|
|
|
|
) |
|
|
|
|
|
def test_router_generic_request_typed_dict_lifespan_state() -> None: |
|
|
|
|
|
class MyClass: |
|
|
|
|
|
async def __aenter__(self) -> Self: |
|
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
async def __aexit__(self, exc_type, exc_value, traceback) -> None: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class MyState(TypedDict): |
|
|
|
|
|
my_class: MyClass |
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager |
|
|
|
|
|
async def lifespan(app: FastAPI) -> AsyncGenerator[MyState]: |
|
|
|
|
|
async with MyClass() as my_class: |
|
|
|
|
|
yield {"my_class": my_class} |
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan) |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
|
|
|
def main(request: Request[MyState]) -> dict[str, str]: |
|
|
|
|
|
assert isinstance(request.state["my_class"], MyClass) |
|
|
|
|
|
return {"message": "Hello World"} |
|
|
|
|
|
|
|
|
|
|
|
with TestClient(app) as client: |
|
|
|
|
|
response = client.get("/") |
|
|
|
|
|
assert response.status_code == 200, response.text |
|
|
|
|
|
assert response.json() == {"message": "Hello World"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_router_nested_lifespan_state_overriding_by_parent() -> None: |
|
|
def test_router_nested_lifespan_state_overriding_by_parent() -> None: |
|
|
@asynccontextmanager |
|
|
@asynccontextmanager |
|
|
async def lifespan( |
|
|
async def lifespan( |
|
|
|