diff --git a/fastapi/routing.py b/fastapi/routing.py index 21a1385a27..060c4104fa 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -808,6 +808,33 @@ class APIWebSocketRoute(routing.WebSocketRoute): return match, child_scope +_ACCEPTS_KWARG_CACHE: dict[tuple[type, str], bool] = {} + + +def _accepts_kwarg(cls: type, kwarg: str) -> bool: + """Check if ``cls.__init__`` accepts a given keyword argument. + + Uses a module-level cache to avoid repeated ``inspect.signature`` calls. + Returns ``True`` when the parameter is explicitly declared or when the + signature includes ``**kwargs``. + """ + key = (cls, kwarg) + if key not in _ACCEPTS_KWARG_CACHE: + try: + sig = inspect.signature(cls) + params: Mapping[str, inspect.Parameter] = sig.parameters + except (ValueError, TypeError): + params = {} + if kwarg in params: + _ACCEPTS_KWARG_CACHE[key] = True + else: + _ACCEPTS_KWARG_CACHE[key] = any( + p.kind == inspect.Parameter.VAR_KEYWORD + for p in params.values() + ) + return _ACCEPTS_KWARG_CACHE[key] + + class APIRoute(routing.Route): def __init__( self, @@ -840,6 +867,7 @@ class APIRoute(routing.Route): generate_unique_id_function: Callable[["APIRoute"], str] | DefaultPlaceholder = Default(generate_unique_id), strict_content_type: bool | DefaultPlaceholder = Default(True), + **kwargs: Any, ) -> None: self.path = path self.endpoint = endpoint @@ -1383,9 +1411,7 @@ class APIRouter(routing.Router): current_generate_unique_id = get_value_or_default( generate_unique_id_function, self.generate_unique_id_function ) - route = route_class( - self.prefix + path, - endpoint=endpoint, + route_kwargs: dict[str, Any] = dict( response_model=response_model, status_code=status_code, tags=current_tags, @@ -1410,10 +1436,12 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=openapi_extra, generate_unique_id_function=current_generate_unique_id, - strict_content_type=get_value_or_default( - strict_content_type, self.strict_content_type - ), ) + if _accepts_kwarg(route_class, "strict_content_type"): + route_kwargs["strict_content_type"] = get_value_or_default( + strict_content_type, self.strict_content_type + ) + route = route_class(self.prefix + path, endpoint=endpoint, **route_kwargs) self.routes.append(route) def api_route( diff --git a/tests/test_custom_route_class.py b/tests/test_custom_route_class.py index 786c1efc31..43865eea49 100644 --- a/tests/test_custom_route_class.py +++ b/tests/test_custom_route_class.py @@ -1,9 +1,16 @@ import pytest -from fastapi import APIRouter, FastAPI +from enum import Enum +from typing import Any, Callable, Sequence + +from fastapi import APIRouter, FastAPI, params +from fastapi.datastructures import Default, DefaultPlaceholder +from fastapi.responses import JSONResponse, Response from fastapi.routing import APIRoute from fastapi.testclient import TestClient +from fastapi.types import IncEx +from fastapi.utils import generate_unique_id from inline_snapshot import snapshot -from starlette.routing import Route +from starlette.routing import BaseRoute, Route app = FastAPI() @@ -119,3 +126,113 @@ def test_openapi_schema(): }, } ) + + +class LegacyRoute(APIRoute): + """Custom APIRoute that mirrors the pre-strict_content_type signature. + + Regression test for #15503: subclasses with explicit constructors that + don't list ``strict_content_type`` must not break when FastAPI passes it. + """ + + def __init__( + self, + path: str, + endpoint: Callable[..., Any], + *, + response_model: Any = Default(None), + status_code: int | None = None, + tags: list[str | Enum] | None = None, + dependencies: Sequence[params.Depends] | None = None, + summary: str | None = None, + description: str | None = None, + response_description: str = "Successful Response", + responses: dict[int | str, dict[str, Any]] | None = None, + deprecated: bool | None = None, + name: str | None = None, + methods: set[str] | list[str] | None = None, + operation_id: str | None = None, + response_model_include: IncEx | None = None, + response_model_exclude: IncEx | None = None, + response_model_by_alias: bool = True, + response_model_exclude_unset: bool = False, + response_model_exclude_defaults: bool = False, + response_model_exclude_none: bool = False, + include_in_schema: bool = True, + response_class: type[Response] | DefaultPlaceholder = Default( + JSONResponse + ), + dependency_overrides_provider: Any | None = None, + callbacks: list[BaseRoute] | None = None, + openapi_extra: dict[str, Any] | None = None, + generate_unique_id_function: Callable[[APIRoute], str] + | DefaultPlaceholder = Default(generate_unique_id), + ) -> None: + super().__init__( + path, + endpoint, + response_model=response_model, + status_code=status_code, + tags=tags, + dependencies=dependencies, + summary=summary, + description=description, + response_description=response_description, + responses=responses, + deprecated=deprecated, + name=name, + methods=methods, + operation_id=operation_id, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + include_in_schema=include_in_schema, + response_class=response_class, + dependency_overrides_provider=dependency_overrides_provider, + callbacks=callbacks, + openapi_extra=openapi_extra, + generate_unique_id_function=generate_unique_id_function, + ) + + +app_legacy = FastAPI() +router_legacy = APIRouter(route_class=LegacyRoute) + + +@router_legacy.get("/items") +def read_items_legacy(): + return {"ok": True} + + +app_legacy.include_router(router_legacy) + +client_legacy = TestClient(app_legacy) + + +def test_custom_route_explicit_constructor_no_strict_content_type(): + """Reproduce #15503: explicit constructor without strict_content_type.""" + response = client_legacy.get("/items") + assert response.status_code == 200 + assert response.json() == {"ok": True} + + +def test_custom_route_explicit_constructor_include_router(): + """LegacyRoute should also survive include_router merging.""" + inner_router = APIRouter(route_class=LegacyRoute) + + @inner_router.get("/inner") + def inner_endpoint(): + return {"inner": True} + + outer_router = APIRouter() + outer_router.include_router(inner_router, prefix="/r") + + app_outer = FastAPI() + app_outer.include_router(outer_router, prefix="/outer") + + resp = TestClient(app_outer).get("/outer/r/inner") + assert resp.status_code == 200 + assert resp.json() == {"inner": True}