Browse Source

fix(routing): conditionally pass strict_content_type to custom APIRoute classes (#15503)

Custom APIRoute subclasses that list the previous __init__ signature
(without strict_content_type) broke when strict_content_type was
introduced. This fix inspects the subclass __init__ signature and
only passes strict_content_type when it's explicitly accepted or
when **kwargs is present, restoring backward compatibility.

Includes regression tests for both the broken and **kwargs paths.
pull/15508/head
Alex 4 weeks ago
parent
commit
0cf10d79a7
  1. 69
      fastapi/routing.py
  2. 120
      tests/test_custom_route_class.py

69
fastapi/routing.py

@ -1383,37 +1383,46 @@ class APIRouter(routing.Router):
current_generate_unique_id = get_value_or_default(
generate_unique_id_function, self.generate_unique_id_function
)
route = route_class(
self.prefix + path,
endpoint=endpoint,
response_model=response_model,
status_code=status_code,
tags=current_tags,
dependencies=current_dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=combined_responses,
deprecated=deprecated or self.deprecated,
methods=methods,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema and self.include_in_schema,
response_class=current_response_class,
name=name,
dependency_overrides_provider=self.dependency_overrides_provider,
callbacks=current_callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id,
strict_content_type=get_value_or_default(
route_init_sig = inspect.signature(route_class.__init__)
route_init_params = route_init_sig.parameters
route_kwargs: dict[str, Any] = {
"endpoint": endpoint,
"response_model": response_model,
"status_code": status_code,
"tags": current_tags,
"dependencies": current_dependencies,
"summary": summary,
"description": description,
"response_description": response_description,
"responses": combined_responses,
"deprecated": deprecated or self.deprecated,
"methods": methods,
"operation_id": operation_id,
"response_model_include": response_model_include,
"response_model_exclude": response_model_exclude,
"response_model_by_alias": response_model_by_alias,
"response_model_exclude_unset": response_model_exclude_unset,
"response_model_exclude_defaults": response_model_exclude_defaults,
"response_model_exclude_none": response_model_exclude_none,
"include_in_schema": include_in_schema and self.include_in_schema,
"response_class": current_response_class,
"name": name,
"dependency_overrides_provider": self.dependency_overrides_provider,
"callbacks": current_callbacks,
"openapi_extra": openapi_extra,
"generate_unique_id_function": current_generate_unique_id,
}
if (
"strict_content_type" in route_init_params
or any(
p.kind == inspect.Parameter.VAR_KEYWORD
for p in route_init_params.values()
)
):
route_kwargs["strict_content_type"] = get_value_or_default(
strict_content_type, self.strict_content_type
),
)
)
route = route_class(self.prefix + path, **route_kwargs)
self.routes.append(route)
def api_route(

120
tests/test_custom_route_class.py

@ -1,9 +1,17 @@
from collections.abc import Sequence
from enum import Enum
from typing import Any, Callable
import pytest
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter, FastAPI, params
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.responses import JSONResponse, Response
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from fastapi.types import IncEx
from fastapi.utils import generate_unique_id
from inline_snapshot import snapshot
from starlette.routing import Route
from starlette.routing import BaseRoute, Route
app = FastAPI()
@ -119,3 +127,111 @@ def test_openapi_schema():
},
}
)
def test_custom_api_route_without_strict_content_type():
"""
Regression test for #15503:
Custom APIRoute classes that explicitly list the previous
__init__ parameters (without strict_content_type) should still
work when registered via APIRouter.
"""
class LegacyRoute(APIRoute):
def __init__(
self,
path: str,
endpoint: Callable[..., Any],
*,
response_model: Any = Default(None),
status_code: int | None = None,
tags: list[str | Enum] | None = None,
dependencies: Sequence[params.Depends] | None = None,
summary: str | None = None,
description: str | None = None,
response_description: str = "Successful Response",
responses: dict[int | str, dict[str, Any]] | None = None,
deprecated: bool | None = None,
name: str | None = None,
methods: set[str] | list[str] | None = None,
operation_id: str | None = None,
response_model_include: IncEx | None = None,
response_model_exclude: IncEx | None = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: type[Response] | DefaultPlaceholder = Default(JSONResponse),
dependency_overrides_provider: Any | None = None,
callbacks: list[BaseRoute] | None = None,
openapi_extra: dict[str, Any] | None = None,
generate_unique_id_function: Callable[[APIRoute], str]
| DefaultPlaceholder = Default(generate_unique_id),
) -> None:
super().__init__(
path,
endpoint,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
name=name,
methods=methods,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
dependency_overrides_provider=dependency_overrides_provider,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
)
app_legacy = FastAPI()
router_legacy = APIRouter(route_class=LegacyRoute)
@router_legacy.get("/items")
def read_items():
return {"ok": True}
# This must not raise TypeError
app_legacy.include_router(router_legacy)
client_legacy = TestClient(app_legacy)
response = client_legacy.get("/items")
assert response.status_code == 200
assert response.json() == {"ok": True}
def test_custom_api_route_with_kwargs():
"""
Custom APIRoute classes using **kwargs should still receive
strict_content_type and work correctly.
"""
class KwargsRoute(APIRoute):
def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
super().__init__(path, endpoint, **kwargs)
app_kwargs = FastAPI()
router_kwargs = APIRouter(route_class=KwargsRoute)
@router_kwargs.get("/items")
def read_items_kwargs():
return {"ok": True}
app_kwargs.include_router(router_kwargs)
client_kwargs = TestClient(app_kwargs)
response = client_kwargs.get("/items")
assert response.status_code == 200
assert response.json() == {"ok": True}

Loading…
Cancel
Save