From 86d9bb197aa61838fc1196674f8c630aae72741f Mon Sep 17 00:00:00 2001 From: JSCU-CNI <121175071+JSCU-CNI@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:52:33 +0100 Subject: [PATCH] Add tests and fix handling of multiple fields --- fastapi/dependencies/utils.py | 58 +++++++------ tests/test_multiple_parameter_models.py | 106 ++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 30 deletions(-) create mode 100644 tests/test_multiple_parameter_models.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 83a7a44c5..b3821e814 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -750,30 +750,34 @@ def request_params_to_args( if not fields: return values, errors - first_field = fields[0] - fields_to_extract = fields - single_not_embedded_field = False - default_convert_underscores = True - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - single_not_embedded_field = True - # If headers are in a Pydantic model, the way to disable convert_underscores - # would be with Header(convert_underscores=False) at the Pydantic model level - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) + fields_to_extract = [] + for field in fields: + if lenient_issubclass(field.type_, BaseModel): + fields_to_extract.extend( + [ + (model_field, field) + for model_field in get_cached_model_fields(field.type_) + ] + ) + else: + fields_to_extract.append((field, field)) params_to_process: Dict[str, Any] = {} processed_keys = set() - for field in fields_to_extract: + for field, base_field in fields_to_extract: alias = None if isinstance(received_params, Headers): # Handle fields extracted from a Pydantic Model for a header, each field # doesn't have a FieldInfo of type Header with the default convert_underscores=True + + # If headers are in a Pydantic model, the way to disable convert_underscores + # would be with Header(convert_underscores=False) at the Pydantic model level convert_underscores = getattr( - field.field_info, "convert_underscores", default_convert_underscores + field.field_info, + "convert_underscores", + getattr(base_field.field_info, "convert_underscores", True) ) if convert_underscores: alias = ( @@ -791,24 +795,18 @@ def request_params_to_args( if key not in processed_keys: params_to_process[key] = value - if single_not_embedded_field: - field_info = first_field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) - loc: Tuple[str, ...] = (field_info.in_.value,) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=params_to_process, values=values, loc=loc - ) - return {first_field.name: v_}, errors_ - for field in fields: - value = _get_multidict_value(field, received_params) field_info = field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) - loc = (field_info.in_.value, field.alias) + assert isinstance( + field_info, params.Param + ), "Params must be subclasses of Param" + + if lenient_issubclass(field.type_, BaseModel): + value = params_to_process + loc: tuple[str, ...] = (field_info.in_.value,) + else: + value = _get_multidict_value(field, received_params) + loc = (field_info.in_.value, field.alias) v_, errors_ = _validate_value_with_model_field( field=field, value=value, values=values, loc=loc ) diff --git a/tests/test_multiple_parameter_models.py b/tests/test_multiple_parameter_models.py new file mode 100644 index 000000000..e84efca68 --- /dev/null +++ b/tests/test_multiple_parameter_models.py @@ -0,0 +1,106 @@ +import pytest +from fastapi import Cookie, Depends, FastAPI, Header, Query +from fastapi.testclient import TestClient +from pydantic import BaseModel, Field + +app = FastAPI() + + +class Model(BaseModel): + field1: int = Field(0) + + +class Model2(BaseModel): + field2: int = Field(0) + + +for param in (Query, Header, Cookie): + + def dependency(field2: int = param(0)): + return field2 + + @app.get(f"/{param.__name__.lower()}-depends/") + async def with_depends(model1: Model = param(), dependency=Depends(dependency)): + return {"field1": model1.field1, "field2": dependency} + + @app.get(f"/{param.__name__.lower()}-argument/") + async def with_model_and_argument(model1: Model = param(), field2: int = param(0)): + return {"field1": model1.field1, "field2": field2} + + @app.get(f"/{param.__name__.lower()}-models/") + async def with_models(model1: Model = param(), model2: Model2 = param()): + return {"field1": model1.field1, "field2": model2.field2} + + @app.get(f"/{param.__name__.lower()}-arguments/") + async def with_argument(field1: int = param(0), field2: int = param(0)): + return {"field1": field1, "field2": field2} + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "path", + ["/query-depends/", "/query-arguments/", "/query-argument/", "/query-models/"], +) +def test_query_depends(path): + response = client.get(path, params={"field1": 0, "field2": 1}) + assert response.status_code == 200 + assert response.json() == {"field1": 0, "field2": 1} + + +@pytest.mark.parametrize( + "path", + ["/header-depends/", "/header-arguments/", "/header-argument/", "/header-models/"], +) +def test_header_depends(path): + response = client.get(path, headers={"field1": "0", "field2": "1"}) + assert response.status_code == 200 + assert response.json() == {"field1": 0, "field2": 1} + + +@pytest.mark.parametrize( + "path", + ["/cookie-depends/", "/cookie-arguments/", "/cookie-argument/", "/cookie-models/"], +) +def test_cookie_depends(path): + client.cookies = {"field1": "0", "field2": "1"} + response = client.get(path) + assert response.status_code == 200 + assert response.json() == {"field1": 0, "field2": 1} + + +@pytest.mark.parametrize( + ("path", "in_"), + [ + ("/query-depends/", "query"), + ("/query-arguments/", "query"), + ("/query-argument/", "query"), + ("/query-models/", "query"), + ("/header-depends/", "header"), + ("/header-arguments/", "header"), + ("/header-argument/", "header"), + ("/header-models/", "header"), + ("/cookie-depends/", "cookie"), + ("/cookie-arguments/", "cookie"), + ("/cookie-argument/", "cookie"), + ("/cookie-models/", "cookie"), + ], +) +def test_parameters_openapi_schema(path, in_): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()["paths"][path]["get"]["parameters"] == [ + { + "name": "field1", + "in": in_, + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field1"}, + }, + { + "name": "field2", + "in": in_, + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field2"}, + }, + ]