Browse Source

Merge 66b486bf34 into 6df50d40fe

pull/12944/merge
Yurii Karabas 3 days ago
committed by GitHub
parent
commit
f8d943704a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 66
      fastapi/dependencies/utils.py
  2. 137
      tests/test_multiple_params_models.py

66
fastapi/dependencies/utils.py

@ -212,11 +212,14 @@ def get_flat_dependant(
def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
if not fields:
return fields
first_field = fields[0]
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
fields_to_extract = get_cached_model_fields(first_field.type_)
return fields_to_extract
return fields
fields_to_extract = []
for f in fields:
if lenient_issubclass(f.type_, BaseModel):
fields_to_extract.extend(get_cached_model_fields(f.type_))
else:
fields_to_extract.append(f)
return fields_to_extract
def get_flat_params(dependant: Dependant) -> List[ModelField]:
@ -747,30 +750,28 @@ def request_params_to_args(
if not fields:
return values, errors
first_field = fields[0]
fields_to_extract = fields
single_not_embedded_field = False
default_convert_underscores = True
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
fields_to_extract = get_cached_model_fields(first_field.type_)
single_not_embedded_field = True
# If headers are in a Pydantic model, the way to disable convert_underscores
# would be with Header(convert_underscores=False) at the Pydantic model level
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
params_to_process: Dict[str, Any] = {}
fields_to_extract = [
(field, cached_field)
for field in fields
if lenient_issubclass(field.type_, BaseModel)
for cached_field in get_cached_model_fields(field.type_)
]
processed_keys = set()
for field in fields_to_extract:
for parent_field, field in fields_to_extract:
alias = None
if isinstance(received_params, Headers):
# Handle fields extracted from a Pydantic Model for a header, each field
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
convert_underscores = getattr(
field.field_info, "convert_underscores", default_convert_underscores
parent_field.field_info,
"convert_underscores",
default_convert_underscores,
)
if convert_underscores:
alias = (
@ -788,27 +789,24 @@ def request_params_to_args(
if key not in processed_keys:
params_to_process[key] = value
if single_not_embedded_field:
field_info = first_field.field_info
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
loc: Tuple[str, ...] = (field_info.in_.value,)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_
for field in fields:
value = _get_multidict_value(field, received_params)
field_info = field.field_info
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
loc = (field_info.in_.value, field.alias)
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc
)
if lenient_issubclass(field.type_, BaseModel):
loc: Tuple[str, ...] = (field_info.in_.value,)
v_, errors_ = _validate_value_with_model_field(
field=field, value=params_to_process, values=values, loc=loc
)
else:
value = _get_multidict_value(field, received_params)
loc = (field_info.in_.value, field.alias)
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc
)
if errors_:
errors.extend(errors_)
else:

137
tests/test_multiple_params_models.py

@ -0,0 +1,137 @@
from typing import Any, Callable
import pytest
from fastapi import APIRouter, Cookie, FastAPI, Header, Query, status
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing_extensions import Annotated
app = FastAPI()
client = TestClient(app)
class NameModel(BaseModel):
name: str
class AgeModel(BaseModel):
age: int
def add_routes(
in_: Callable[..., Any],
prefix: str,
) -> None:
router = APIRouter(prefix=prefix)
@router.get("/models")
async def route_models(
name_model: Annotated[NameModel, in_()],
age_model: Annotated[AgeModel, in_()],
):
return {
"name": name_model.name,
"age": age_model.age,
}
@router.get("/mixed")
async def route_mixed(
name_model: Annotated[NameModel, in_()],
age: Annotated[int, in_()],
):
return {
"name": name_model.name,
"age": age,
}
app.include_router(router)
add_routes(Query, "/query")
add_routes(Header, "/header")
add_routes(Cookie, "/cookie")
@pytest.mark.parametrize(
("in_", "prefix", "call_arg"),
[
(Query, "/query", "params"),
(Header, "/header", "headers"),
(Cookie, "/cookie", "cookies"),
],
ids=[
"query",
"header",
"cookie",
],
)
@pytest.mark.parametrize(
"type_",
[
"models",
"mixed",
],
ids=[
"models",
"mixed",
],
)
def test_multiple_params(in_, prefix, call_arg, type_):
params = {"name": "John", "age": "42"}
kwargs = {}
if call_arg == "cookies":
client.cookies = params
else:
kwargs[call_arg] = params
response = client.get(f"{prefix}/{type_}", **kwargs)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"name": "John", "age": 42}
@pytest.mark.parametrize(
("prefix", "in_"),
[
("/query", "query"),
("/header", "header"),
("/cookie", "cookie"),
],
ids=[
"query",
"header",
"cookie",
],
)
@pytest.mark.parametrize(
"type_",
[
"models",
"mixed",
],
ids=[
"models",
"mixed",
],
)
def test_openapi_schema(prefix, in_, type_):
response = client.get("/openapi.json")
assert response.status_code == status.HTTP_200_OK
schema = response.json()
assert schema["paths"][f"{prefix}/{type_}"]["get"]["parameters"] == [
{
"required": True,
"in": in_,
"name": "name",
"schema": {"title": "Name", "type": "string"},
},
{
"required": True,
"in": in_,
"name": "age",
"schema": {"title": "Age", "type": "integer"},
},
]
Loading…
Cancel
Save