diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..635950877 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -212,11 +212,28 @@ def get_flat_dependant( def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: if not fields: return fields - first_field = fields[0] - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - return fields_to_extract - return fields + + return [field for field, _ in _get_flat_fields_from_params_with_origin(fields)] + + +def _get_flat_fields_from_params_with_origin( + fields: Sequence[ModelField], +) -> Sequence[Tuple[ModelField, ModelField]]: + """Same as :func:`_get_flat_fields_from_params`, but returns the result + as tuples ``(flat_field, origin_field)``. + """ + result = [] + for field in fields: + if lenient_issubclass(field.type_, BaseModel): + result.extend( + [ + (model_field, field) + for model_field in get_cached_model_fields(field.type_) + ] + ) + else: + result.append((field, field)) + return result def get_flat_params(dependant: Dependant) -> List[ModelField]: @@ -747,30 +764,23 @@ def request_params_to_args( if not fields: return values, errors - first_field = fields[0] - fields_to_extract = fields - single_not_embedded_field = False - default_convert_underscores = True - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - single_not_embedded_field = True - # If headers are in a Pydantic model, the way to disable convert_underscores - # would be with Header(convert_underscores=False) at the Pydantic model level - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) - + fields_to_extract = _get_flat_fields_from_params_with_origin(fields) params_to_process: Dict[str, Any] = {} processed_keys = set() - for field in fields_to_extract: + for field, base_field in fields_to_extract: alias = None if isinstance(received_params, Headers): # Handle fields extracted from a Pydantic Model for a header, each field # doesn't have a FieldInfo of type Header with the default convert_underscores=True + + # If headers are in a Pydantic model, the way to disable convert_underscores + # would be with Header(convert_underscores=False) at the Pydantic model level convert_underscores = getattr( - field.field_info, "convert_underscores", default_convert_underscores + field.field_info, + "convert_underscores", + getattr(base_field.field_info, "convert_underscores", True), ) if convert_underscores: alias = ( @@ -788,24 +798,18 @@ def request_params_to_args( if key not in processed_keys: params_to_process[key] = value - if single_not_embedded_field: - field_info = first_field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) - loc: Tuple[str, ...] = (field_info.in_.value,) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=params_to_process, values=values, loc=loc - ) - return {first_field.name: v_}, errors_ - for field in fields: - value = _get_multidict_value(field, received_params) field_info = field.field_info assert isinstance(field_info, params.Param), ( "Params must be subclasses of Param" ) - loc = (field_info.in_.value, field.alias) + + if lenient_issubclass(field.type_, BaseModel): + value = params_to_process + loc: tuple[str, ...] = (field_info.in_.value,) + else: + value = _get_multidict_value(field, received_params) + loc = (field_info.in_.value, field.alias) v_, errors_ = _validate_value_with_model_field( field=field, value=value, values=values, loc=loc ) diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 808646cc2..bbce37b9e 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -17,7 +17,7 @@ from fastapi._compat import ( from fastapi.datastructures import DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( - _get_flat_fields_from_params, + _get_flat_fields_from_params_with_origin, get_flat_dependant, get_flat_params, ) @@ -32,7 +32,6 @@ from fastapi.utils import ( generate_operation_id_for_path, is_body_allowed_for_status_code, ) -from pydantic import BaseModel from starlette.responses import JSONResponse from starlette.routing import BaseRoute from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY @@ -104,25 +103,22 @@ def _get_openapi_operation_parameters( ) -> List[Dict[str, Any]]: parameters = [] flat_dependant = get_flat_dependant(dependant, skip_repeats=True) - path_params = _get_flat_fields_from_params(flat_dependant.path_params) - query_params = _get_flat_fields_from_params(flat_dependant.query_params) - header_params = _get_flat_fields_from_params(flat_dependant.header_params) - cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params) + path_params = _get_flat_fields_from_params_with_origin(flat_dependant.path_params) + query_params = _get_flat_fields_from_params_with_origin(flat_dependant.query_params) + header_params = _get_flat_fields_from_params_with_origin( + flat_dependant.header_params + ) + cookie_params = _get_flat_fields_from_params_with_origin( + flat_dependant.cookie_params + ) parameter_groups = [ (ParamTypes.path, path_params), (ParamTypes.query, query_params), (ParamTypes.header, header_params), (ParamTypes.cookie, cookie_params), ] - default_convert_underscores = True - if len(flat_dependant.header_params) == 1: - first_field = flat_dependant.header_params[0] - if lenient_issubclass(first_field.type_, BaseModel): - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) for param_type, param_group in parameter_groups: - for param in param_group: + for param, base_field in param_group: field_info = param.field_info # field_info = cast(Param, field_info) if not getattr(field_info, "include_in_schema", True): @@ -138,7 +134,7 @@ def _get_openapi_operation_parameters( convert_underscores = getattr( param.field_info, "convert_underscores", - default_convert_underscores, + getattr(base_field.field_info, "convert_underscores", True), ) if ( param_type == ParamTypes.header diff --git a/tests/test_multiple_parameter_models.py b/tests/test_multiple_parameter_models.py new file mode 100644 index 000000000..4194e03fe --- /dev/null +++ b/tests/test_multiple_parameter_models.py @@ -0,0 +1,277 @@ +import pytest +from fastapi import Cookie, Depends, FastAPI, Header, Query +from fastapi._compat import PYDANTIC_V2 +from fastapi.testclient import TestClient +from pydantic import BaseModel, ConfigDict, Field + +app = FastAPI() + + +class Model(BaseModel): + field_1: int = Field(0) + + +class Model2(BaseModel): + field_2: int = Field(0) + + +class ModelNoExtra(BaseModel): + field_1: int = Field(0) + if PYDANTIC_V2: + model_config = ConfigDict(extra="forbid") + else: + + class Config: + extra = "forbid" + + +def HeaderU(*args, **kwargs): + """Header callable that ensures that convert_underscores is False.""" + return Header(*args, convert_underscores=False, **kwargs) + + +for param in (Query, Header, HeaderU, Cookie): + # Generates 4 views for all three Query, Header, and Cookie params: + # i.e. /query-depends/, /query-arguments/, /query-argument/, /query-models/ for query + + def dependency(field_2: int = param(0, title="Field 2")): + return field_2 + + @app.get(f"/{param.__name__.lower()}-depends/") + async def with_depends(model1: Model = param(), dependency=Depends(dependency)): + """Model1 is specified via Query()/Header()/Cookie() and Model2 through Depends""" + return {"field_1": model1.field_1, "field_2": dependency} + + @app.get(f"/{param.__name__.lower()}-arguments/") + async def with_argument( + field_1: int = param(0, title="Field 1"), + field_2: int = param(0, title="Field 2"), + ): + """Model1 and Model2 are specified as direct arguments (sanity check)""" + return {"field_1": field_1, "field_2": field_2} + + @app.get(f"/{param.__name__.lower()}-argument/") + async def with_model_and_argument( + model1: Model = param(), field_2: int = param(0, title="Field 2") + ): + """Model1 is specified via Query()/Header()/Cookie() and Model2 as direct argument""" + return {"field_1": model1.field_1, "field_2": field_2} + + @app.get(f"/{param.__name__.lower()}-models/") + async def with_models(model1: Model = param(), model2: Model2 = param()): + """Model1 and Model2 are specified via Query()/Header()/Cookie()""" + return {"field_1": model1.field_1, "field_2": model2.field_2} + + +@app.get("/mixed/") +async def mixed_model_sources(model1: Model = Query(), model2: Model2 = Header()): + """Model1 is specified as Query(), Model2 as Header()""" + return {"field_1": model1.field_1, "field_2": model2.field_2} + + +@app.get("/duplicate/") +async def duplicate_name(model: Model = Query(), same_model: Model = Query()): + """Model1 is specified twice in Query()""" + return {"field_1": model.field_1, "duplicate": same_model.field_1} + + +@app.get("/duplicate2/") +async def duplicate_name2(model: Model = Query(), same_model: Model = Header()): + """Model1 is specified twice, once in Query(), once in Header()""" + return {"field_1": model.field_1, "duplicate": same_model.field_1} + + +@app.get("/duplicate-no-extra/") +async def duplicate_name_no_extra( + model: Model = Query(), same_model: ModelNoExtra = Query() +): + """Uses Model and ModelNoExtra, but they have overlapping names""" + return {"field_1": model.field_1, "duplicate": same_model.field_1} + + +@app.get("/no-extra/") +async def no_extra(model1: ModelNoExtra = Query(), model2: Model2 = Query()): + """Uses Model2 and ModelNoExtra, but they don't have overlapping names""" + pass # pragma: nocover + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "path", + ["/query-depends/", "/query-arguments/", "/query-argument/", "/query-models/"], +) +def test_query_depends(path): + response = client.get(path, params={"field_1": 0, "field_2": 1}) + assert response.status_code == 200 + assert response.json() == {"field_1": 0, "field_2": 1} + + +@pytest.mark.parametrize( + "path", + ["/header-depends/", "/header-arguments/", "/header-argument/", "/header-models/"], +) +def test_header_depends(path): + response = client.get(path, headers={"field-1": "0", "field-2": "1"}) + assert response.status_code == 200 + assert response.json() == {"field_1": 0, "field_2": 1} + + +@pytest.mark.parametrize( + "path", + [ + "/headeru-depends/", + "/headeru-arguments/", + "/headeru-argument/", + "/headeru-models/", + ], +) +def test_headeru_depends(path): + response = client.get(path, headers={"field_1": "0", "field_2": "1"}) + assert response.status_code == 200 + assert response.json() == {"field_1": 0, "field_2": 1} + + +@pytest.mark.parametrize( + "path", + ["/cookie-depends/", "/cookie-arguments/", "/cookie-argument/", "/cookie-models/"], +) +def test_cookie_depends(path): + client.cookies = {"field_1": "0", "field_2": "1"} + response = client.get(path) + assert response.status_code == 200 + assert response.json() == {"field_1": 0, "field_2": 1} + + +def test_mixed(): + response = client.get("/mixed/", params={"field_1": 0}, headers={"field-2": "1"}) + assert response.status_code == 200 + assert response.json() == {"field_1": 0, "field_2": 1} + + +@pytest.mark.parametrize( + "path", + ["/duplicate/", "/duplicate2/", "/duplicate-no-extra/"], +) +def test_duplicate_name(path): + response = client.get(path, params={"field_1": 0}) + assert response.status_code == 200 + assert response.json() == {"field_1": 0, "duplicate": 0} + + +def test_no_extra(): + response = client.get("/no-extra/", params={"field_1": 0, "field_2": 1}) + assert response.status_code == 422 + if PYDANTIC_V2: + assert response.json() == { + "detail": [ + { + "input": "1", + "loc": ["query", "field_2"], + "msg": "Extra inputs are not permitted", + "type": "extra_forbidden", + } + ] + } + else: + assert response.json() == { + "detail": [ + { + "loc": ["query", "field_2"], + "msg": "extra fields not permitted", + "type": "value_error.extra", + } + ] + } + + +@pytest.mark.parametrize( + ("path", "in_", "convert_underscores"), + [ + ("/query-depends/", "query", False), + ("/query-arguments/", "query", False), + ("/query-argument/", "query", False), + ("/query-models/", "query", False), + ("/header-depends/", "header", True), + ("/header-arguments/", "header", True), + ("/header-argument/", "header", True), + ("/header-models/", "header", True), + ("/headeru-depends/", "header", False), + ("/headeru-arguments/", "header", False), + ("/headeru-argument/", "header", False), + ("/headeru-models/", "header", False), + ("/cookie-depends/", "cookie", False), + ("/cookie-arguments/", "cookie", False), + ("/cookie-argument/", "cookie", False), + ("/cookie-models/", "cookie", False), + ], +) +def test_parameters_openapi_schema(path, in_, convert_underscores): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()["paths"][path]["get"]["parameters"] == [ + { + "name": "field-1" if convert_underscores else "field_1", + "in": in_, + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, + }, + { + "name": "field-2" if convert_underscores else "field_2", + "in": in_, + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field 2"}, + }, + ] + + +def test_parameters_openapi_mixed(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()["paths"]["/mixed/"]["get"]["parameters"] == [ + { + "name": "field_1", + "in": "query", + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, + }, + { + "name": "field-2", + "in": "header", + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field 2"}, + }, + ] + + +def test_parameters_openapi_duplicate_name(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()["paths"]["/duplicate/"]["get"]["parameters"] == [ + { + "name": "field_1", + "in": "query", + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, + }, + ] + + +def test_parameters_openapi_duplicate_name2(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()["paths"]["/duplicate2/"]["get"]["parameters"] == [ + { + "name": "field_1", + "in": "query", + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, + }, + { + "name": "field-1", + "in": "header", + "required": False, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, + }, + ]