diff --git a/fastapi/routing.py b/fastapi/routing.py index 21a1385a27..5f9eda87c6 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1383,37 +1383,46 @@ class APIRouter(routing.Router): current_generate_unique_id = get_value_or_default( generate_unique_id_function, self.generate_unique_id_function ) - route = route_class( - self.prefix + path, - endpoint=endpoint, - response_model=response_model, - status_code=status_code, - tags=current_tags, - dependencies=current_dependencies, - summary=summary, - description=description, - response_description=response_description, - responses=combined_responses, - deprecated=deprecated or self.deprecated, - methods=methods, - operation_id=operation_id, - response_model_include=response_model_include, - response_model_exclude=response_model_exclude, - response_model_by_alias=response_model_by_alias, - response_model_exclude_unset=response_model_exclude_unset, - response_model_exclude_defaults=response_model_exclude_defaults, - response_model_exclude_none=response_model_exclude_none, - include_in_schema=include_in_schema and self.include_in_schema, - response_class=current_response_class, - name=name, - dependency_overrides_provider=self.dependency_overrides_provider, - callbacks=current_callbacks, - openapi_extra=openapi_extra, - generate_unique_id_function=current_generate_unique_id, - strict_content_type=get_value_or_default( + route_init_sig = inspect.signature(route_class.__init__) + route_init_params = route_init_sig.parameters + route_kwargs: dict[str, Any] = { + "endpoint": endpoint, + "response_model": response_model, + "status_code": status_code, + "tags": current_tags, + "dependencies": current_dependencies, + "summary": summary, + "description": description, + "response_description": response_description, + "responses": combined_responses, + "deprecated": deprecated or self.deprecated, + "methods": methods, + "operation_id": operation_id, + "response_model_include": response_model_include, + "response_model_exclude": response_model_exclude, + "response_model_by_alias": response_model_by_alias, + "response_model_exclude_unset": response_model_exclude_unset, + "response_model_exclude_defaults": response_model_exclude_defaults, + "response_model_exclude_none": response_model_exclude_none, + "include_in_schema": include_in_schema and self.include_in_schema, + "response_class": current_response_class, + "name": name, + "dependency_overrides_provider": self.dependency_overrides_provider, + "callbacks": current_callbacks, + "openapi_extra": openapi_extra, + "generate_unique_id_function": current_generate_unique_id, + } + if ( + "strict_content_type" in route_init_params + or any( + p.kind == inspect.Parameter.VAR_KEYWORD + for p in route_init_params.values() + ) + ): + route_kwargs["strict_content_type"] = get_value_or_default( strict_content_type, self.strict_content_type - ), - ) + ) + route = route_class(self.prefix + path, **route_kwargs) self.routes.append(route) def api_route( diff --git a/tests/test_custom_route_class.py b/tests/test_custom_route_class.py index 786c1efc31..49ce7768a4 100644 --- a/tests/test_custom_route_class.py +++ b/tests/test_custom_route_class.py @@ -1,9 +1,17 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Any, Callable + import pytest -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, FastAPI, params +from fastapi.datastructures import Default, DefaultPlaceholder +from fastapi.responses import JSONResponse, Response from fastapi.routing import APIRoute from fastapi.testclient import TestClient +from fastapi.types import IncEx +from fastapi.utils import generate_unique_id from inline_snapshot import snapshot -from starlette.routing import Route +from starlette.routing import BaseRoute, Route app = FastAPI() @@ -119,3 +127,111 @@ def test_openapi_schema(): }, } ) + +def test_custom_api_route_without_strict_content_type(): + """ + Regression test for #15503: + Custom APIRoute classes that explicitly list the previous + __init__ parameters (without strict_content_type) should still + work when registered via APIRouter. + """ + + class LegacyRoute(APIRoute): + def __init__( + self, + path: str, + endpoint: Callable[..., Any], + *, + response_model: Any = Default(None), + status_code: int | None = None, + tags: list[str | Enum] | None = None, + dependencies: Sequence[params.Depends] | None = None, + summary: str | None = None, + description: str | None = None, + response_description: str = "Successful Response", + responses: dict[int | str, dict[str, Any]] | None = None, + deprecated: bool | None = None, + name: str | None = None, + methods: set[str] | list[str] | None = None, + operation_id: str | None = None, + response_model_include: IncEx | None = None, + response_model_exclude: IncEx | None = None, + response_model_by_alias: bool = True, + response_model_exclude_unset: bool = False, + response_model_exclude_defaults: bool = False, + response_model_exclude_none: bool = False, + include_in_schema: bool = True, + response_class: type[Response] | DefaultPlaceholder = Default(JSONResponse), + dependency_overrides_provider: Any | None = None, + callbacks: list[BaseRoute] | None = None, + openapi_extra: dict[str, Any] | None = None, + generate_unique_id_function: Callable[[APIRoute], str] + | DefaultPlaceholder = Default(generate_unique_id), + ) -> None: + super().__init__( + path, + endpoint, + response_model=response_model, + status_code=status_code, + tags=tags, + dependencies=dependencies, + summary=summary, + description=description, + response_description=response_description, + responses=responses, + deprecated=deprecated, + name=name, + methods=methods, + operation_id=operation_id, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + include_in_schema=include_in_schema, + response_class=response_class, + dependency_overrides_provider=dependency_overrides_provider, + callbacks=callbacks, + openapi_extra=openapi_extra, + generate_unique_id_function=generate_unique_id_function, + ) + + app_legacy = FastAPI() + router_legacy = APIRouter(route_class=LegacyRoute) + + @router_legacy.get("/items") + def read_items(): + return {"ok": True} + + # This must not raise TypeError + app_legacy.include_router(router_legacy) + + client_legacy = TestClient(app_legacy) + response = client_legacy.get("/items") + assert response.status_code == 200 + assert response.json() == {"ok": True} + +def test_custom_api_route_with_kwargs(): + """ + Custom APIRoute classes using **kwargs should still receive + strict_content_type and work correctly. + """ + + class KwargsRoute(APIRoute): + def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None: + super().__init__(path, endpoint, **kwargs) + + app_kwargs = FastAPI() + router_kwargs = APIRouter(route_class=KwargsRoute) + + @router_kwargs.get("/items") + def read_items_kwargs(): + return {"ok": True} + + app_kwargs.include_router(router_kwargs) + + client_kwargs = TestClient(app_kwargs) + response = client_kwargs.get("/items") + assert response.status_code == 200 + assert response.json() == {"ok": True}