Browse Source

Merge c940ab7197 into 460f8d2cc8

pull/14863/merge
ollie-bell 12 hours ago
committed by GitHub
parent
commit
6f98dfcbd0
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 3
      fastapi/dependencies/utils.py
  2. 38
      tests/test_router_events.py

3
fastapi/dependencies/utils.py

@ -362,6 +362,7 @@ def get_dependant(
def add_non_field_param_to_dependency( def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant *, param_name: str, type_annotation: Any, dependant: Dependant
) -> bool | None: ) -> bool | None:
type_annotation = get_origin(type_annotation) or type_annotation
if lenient_issubclass(type_annotation, Request): if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name dependant.request_param_name = param_name
return True return True
@ -486,7 +487,7 @@ def analyze_param(
# Only apply special handling when there's no explicit Depends - if there's a Depends, # Only apply special handling when there's no explicit Depends - if there's a Depends,
# the dependency will be called and its return value used instead of the special injection # the dependency will be called and its return value used instead of the special injection
if depends is None and lenient_issubclass( if depends is None and lenient_issubclass(
type_annotation, get_origin(type_annotation) or type_annotation,
( (
Request, Request,
WebSocket, WebSocket,

38
tests/test_router_events.py

@ -1,10 +1,15 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import TypedDict
import pytest import pytest
from fastapi import APIRouter, FastAPI, Request from fastapi import APIRouter, FastAPI, Request
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel from pydantic import BaseModel
from starlette import __version__ as STARLETTE_VERSION
from typing_extensions import Self
STARLETTE_MINOR_VERSION_TUPLE = tuple(int(x) for x in STARLETTE_VERSION.split(".")[:2])
class State(BaseModel): class State(BaseModel):
@ -171,6 +176,39 @@ def test_router_nested_lifespan_state(state: State) -> None:
assert state.sub_router_shutdown is True assert state.sub_router_shutdown is True
@pytest.mark.skipif(
STARLETTE_MINOR_VERSION_TUPLE < (0, 52),
reason="Starlette Request with generic type is not supported in Starlette < 0.52.0",
)
def test_router_generic_request_typed_dict_lifespan_state() -> None:
class MyClass:
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass
class MyState(TypedDict):
my_class: MyClass
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[MyState]:
async with MyClass() as my_class:
yield {"my_class": my_class}
app = FastAPI(lifespan=lifespan)
@app.get("/")
def main(request: Request[MyState]) -> dict[str, str]:
assert isinstance(request.state["my_class"], MyClass)
return {"message": "Hello World"}
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
def test_router_nested_lifespan_state_overriding_by_parent() -> None: def test_router_nested_lifespan_state_overriding_by_parent() -> None:
@asynccontextmanager @asynccontextmanager
async def lifespan( async def lifespan(

Loading…
Cancel
Save