Browse Source

fix: support custom APIRoute subclasses with explicit constructors

Custom APIRoute subclasses that define an explicit __init__ without
accepting `strict_content_type` would raise TypeError after the
addition of that parameter. Fixed by checking whether the route class
accepts the keyword argument before passing it.

Fixes #15503
pull/15546/head
theshyxin 3 weeks ago
parent
commit
832a452bec
  1. 40
      fastapi/routing.py
  2. 121
      tests/test_custom_route_class.py

40
fastapi/routing.py

@ -808,6 +808,33 @@ class APIWebSocketRoute(routing.WebSocketRoute):
return match, child_scope
_ACCEPTS_KWARG_CACHE: dict[tuple[type, str], bool] = {}
def _accepts_kwarg(cls: type, kwarg: str) -> bool:
"""Check if ``cls.__init__`` accepts a given keyword argument.
Uses a module-level cache to avoid repeated ``inspect.signature`` calls.
Returns ``True`` when the parameter is explicitly declared or when the
signature includes ``**kwargs``.
"""
key = (cls, kwarg)
if key not in _ACCEPTS_KWARG_CACHE:
try:
sig = inspect.signature(cls)
params: Mapping[str, inspect.Parameter] = sig.parameters
except (ValueError, TypeError):
params = {}
if kwarg in params:
_ACCEPTS_KWARG_CACHE[key] = True
else:
_ACCEPTS_KWARG_CACHE[key] = any(
p.kind == inspect.Parameter.VAR_KEYWORD
for p in params.values()
)
return _ACCEPTS_KWARG_CACHE[key]
class APIRoute(routing.Route):
def __init__(
self,
@ -840,6 +867,7 @@ class APIRoute(routing.Route):
generate_unique_id_function: Callable[["APIRoute"], str]
| DefaultPlaceholder = Default(generate_unique_id),
strict_content_type: bool | DefaultPlaceholder = Default(True),
**kwargs: Any,
) -> None:
self.path = path
self.endpoint = endpoint
@ -1383,9 +1411,7 @@ class APIRouter(routing.Router):
current_generate_unique_id = get_value_or_default(
generate_unique_id_function, self.generate_unique_id_function
)
route = route_class(
self.prefix + path,
endpoint=endpoint,
route_kwargs: dict[str, Any] = dict(
response_model=response_model,
status_code=status_code,
tags=current_tags,
@ -1410,10 +1436,12 @@ class APIRouter(routing.Router):
callbacks=current_callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id,
strict_content_type=get_value_or_default(
strict_content_type, self.strict_content_type
),
)
if _accepts_kwarg(route_class, "strict_content_type"):
route_kwargs["strict_content_type"] = get_value_or_default(
strict_content_type, self.strict_content_type
)
route = route_class(self.prefix + path, endpoint=endpoint, **route_kwargs)
self.routes.append(route)
def api_route(

121
tests/test_custom_route_class.py

@ -1,9 +1,16 @@
import pytest
from fastapi import APIRouter, FastAPI
from enum import Enum
from typing import Any, Callable, Sequence
from fastapi import APIRouter, FastAPI, params
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.responses import JSONResponse, Response
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from fastapi.types import IncEx
from fastapi.utils import generate_unique_id
from inline_snapshot import snapshot
from starlette.routing import Route
from starlette.routing import BaseRoute, Route
app = FastAPI()
@ -119,3 +126,113 @@ def test_openapi_schema():
},
}
)
class LegacyRoute(APIRoute):
"""Custom APIRoute that mirrors the pre-strict_content_type signature.
Regression test for #15503: subclasses with explicit constructors that
don't list ``strict_content_type`` must not break when FastAPI passes it.
"""
def __init__(
self,
path: str,
endpoint: Callable[..., Any],
*,
response_model: Any = Default(None),
status_code: int | None = None,
tags: list[str | Enum] | None = None,
dependencies: Sequence[params.Depends] | None = None,
summary: str | None = None,
description: str | None = None,
response_description: str = "Successful Response",
responses: dict[int | str, dict[str, Any]] | None = None,
deprecated: bool | None = None,
name: str | None = None,
methods: set[str] | list[str] | None = None,
operation_id: str | None = None,
response_model_include: IncEx | None = None,
response_model_exclude: IncEx | None = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: type[Response] | DefaultPlaceholder = Default(
JSONResponse
),
dependency_overrides_provider: Any | None = None,
callbacks: list[BaseRoute] | None = None,
openapi_extra: dict[str, Any] | None = None,
generate_unique_id_function: Callable[[APIRoute], str]
| DefaultPlaceholder = Default(generate_unique_id),
) -> None:
super().__init__(
path,
endpoint,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
name=name,
methods=methods,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
dependency_overrides_provider=dependency_overrides_provider,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
)
app_legacy = FastAPI()
router_legacy = APIRouter(route_class=LegacyRoute)
@router_legacy.get("/items")
def read_items_legacy():
return {"ok": True}
app_legacy.include_router(router_legacy)
client_legacy = TestClient(app_legacy)
def test_custom_route_explicit_constructor_no_strict_content_type():
"""Reproduce #15503: explicit constructor without strict_content_type."""
response = client_legacy.get("/items")
assert response.status_code == 200
assert response.json() == {"ok": True}
def test_custom_route_explicit_constructor_include_router():
"""LegacyRoute should also survive include_router merging."""
inner_router = APIRouter(route_class=LegacyRoute)
@inner_router.get("/inner")
def inner_endpoint():
return {"inner": True}
outer_router = APIRouter()
outer_router.include_router(inner_router, prefix="/r")
app_outer = FastAPI()
app_outer.include_router(outer_router, prefix="/outer")
resp = TestClient(app_outer).get("/outer/r/inner")
assert resp.status_code == 200
assert resp.json() == {"inner": True}

Loading…
Cancel
Save