diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..1f2b7ee30 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -212,11 +212,14 @@ def get_flat_dependant( def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: if not fields: return fields - first_field = fields[0] - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - return fields_to_extract - return fields + + fields_to_extract = [] + for f in fields: + if lenient_issubclass(f.type_, BaseModel): + fields_to_extract.extend(get_cached_model_fields(f.type_)) + else: + fields_to_extract.append(f) + return fields_to_extract def get_flat_params(dependant: Dependant) -> List[ModelField]: @@ -747,30 +750,28 @@ def request_params_to_args( if not fields: return values, errors - first_field = fields[0] - fields_to_extract = fields - single_not_embedded_field = False default_convert_underscores = True - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - single_not_embedded_field = True - # If headers are in a Pydantic model, the way to disable convert_underscores - # would be with Header(convert_underscores=False) at the Pydantic model level - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) params_to_process: Dict[str, Any] = {} + fields_to_extract = [ + (field, cached_field) + for field in fields + if lenient_issubclass(field.type_, BaseModel) + for cached_field in get_cached_model_fields(field.type_) + ] + processed_keys = set() - for field in fields_to_extract: + for parent_field, field in fields_to_extract: alias = None if isinstance(received_params, Headers): # Handle fields extracted from a Pydantic Model for a header, each field # doesn't have a FieldInfo of type Header with the default convert_underscores=True convert_underscores = getattr( - field.field_info, "convert_underscores", default_convert_underscores + parent_field.field_info, + "convert_underscores", + default_convert_underscores, ) if convert_underscores: alias = ( @@ -788,27 +789,24 @@ def request_params_to_args( if key not in processed_keys: params_to_process[key] = value - if single_not_embedded_field: - field_info = first_field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) - loc: Tuple[str, ...] = (field_info.in_.value,) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=params_to_process, values=values, loc=loc - ) - return {first_field.name: v_}, errors_ - for field in fields: - value = _get_multidict_value(field, received_params) field_info = field.field_info assert isinstance(field_info, params.Param), ( "Params must be subclasses of Param" ) - loc = (field_info.in_.value, field.alias) - v_, errors_ = _validate_value_with_model_field( - field=field, value=value, values=values, loc=loc - ) + + if lenient_issubclass(field.type_, BaseModel): + loc: Tuple[str, ...] = (field_info.in_.value,) + v_, errors_ = _validate_value_with_model_field( + field=field, value=params_to_process, values=values, loc=loc + ) + else: + value = _get_multidict_value(field, received_params) + loc = (field_info.in_.value, field.alias) + v_, errors_ = _validate_value_with_model_field( + field=field, value=value, values=values, loc=loc + ) + if errors_: errors.extend(errors_) else: diff --git a/tests/test_multiple_params_models.py b/tests/test_multiple_params_models.py new file mode 100644 index 000000000..f8d7c8b8d --- /dev/null +++ b/tests/test_multiple_params_models.py @@ -0,0 +1,137 @@ +from typing import Any, Callable + +import pytest +from fastapi import APIRouter, Cookie, FastAPI, Header, Query, status +from fastapi.testclient import TestClient +from pydantic import BaseModel +from typing_extensions import Annotated + +app = FastAPI() +client = TestClient(app) + + +class NameModel(BaseModel): + name: str + + +class AgeModel(BaseModel): + age: int + + +def add_routes( + in_: Callable[..., Any], + prefix: str, +) -> None: + router = APIRouter(prefix=prefix) + + @router.get("/models") + async def route_models( + name_model: Annotated[NameModel, in_()], + age_model: Annotated[AgeModel, in_()], + ): + return { + "name": name_model.name, + "age": age_model.age, + } + + @router.get("/mixed") + async def route_mixed( + name_model: Annotated[NameModel, in_()], + age: Annotated[int, in_()], + ): + return { + "name": name_model.name, + "age": age, + } + + app.include_router(router) + + +add_routes(Query, "/query") +add_routes(Header, "/header") +add_routes(Cookie, "/cookie") + + +@pytest.mark.parametrize( + ("in_", "prefix", "call_arg"), + [ + (Query, "/query", "params"), + (Header, "/header", "headers"), + (Cookie, "/cookie", "cookies"), + ], + ids=[ + "query", + "header", + "cookie", + ], +) +@pytest.mark.parametrize( + "type_", + [ + "models", + "mixed", + ], + ids=[ + "models", + "mixed", + ], +) +def test_multiple_params(in_, prefix, call_arg, type_): + params = {"name": "John", "age": "42"} + kwargs = {} + + if call_arg == "cookies": + client.cookies = params + else: + kwargs[call_arg] = params + + response = client.get(f"{prefix}/{type_}", **kwargs) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"name": "John", "age": 42} + + +@pytest.mark.parametrize( + ("prefix", "in_"), + [ + ("/query", "query"), + ("/header", "header"), + ("/cookie", "cookie"), + ], + ids=[ + "query", + "header", + "cookie", + ], +) +@pytest.mark.parametrize( + "type_", + [ + "models", + "mixed", + ], + ids=[ + "models", + "mixed", + ], +) +def test_openapi_schema(prefix, in_, type_): + response = client.get("/openapi.json") + + assert response.status_code == status.HTTP_200_OK + + schema = response.json() + assert schema["paths"][f"{prefix}/{type_}"]["get"]["parameters"] == [ + { + "required": True, + "in": in_, + "name": "name", + "schema": {"title": "Name", "type": "string"}, + }, + { + "required": True, + "in": in_, + "name": "age", + "schema": {"title": "Age", "type": "integer"}, + }, + ]