Browse Source

Handle get_cached_model_fields. Fix type checks. Add more tests

pull/11306/head
Markus Sintonen 2 weeks ago
parent
commit
c4daa7abab
  1. 43
      fastapi/_compat.py
  2. 8
      fastapi/dependencies/utils.py
  3. 163
      tests/test_root_model.py

43
fastapi/_compat.py

@ -239,7 +239,11 @@ if PYDANTIC_V2:
return field_mapping, definitions # type: ignore[return-value] return field_mapping, definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool: def is_scalar_field(field: ModelField) -> bool:
return field_annotation_is_scalar(field.field_info.annotation) from fastapi import params
return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool: def is_sequence_field(field: ModelField) -> bool:
return field_annotation_is_sequence(field.field_info.annotation) return field_annotation_is_sequence(field.field_info.annotation)
@ -402,12 +406,6 @@ else:
def is_pv1_scalar_field(field: ModelField) -> bool: def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params from fastapi import params
if (
lenient_issubclass(field.type_, BaseModel)
and "__root__" in field.type_.__fields__
):
return is_pv1_scalar_field(field.type_.__fields__["__root__"])
field_info = field.field_info field_info = field.field_info
if not ( if not (
field.shape == SHAPE_SINGLETON # type: ignore[attr-defined] field.shape == SHAPE_SINGLETON # type: ignore[attr-defined]
@ -501,12 +499,21 @@ else:
) )
def is_scalar_field(field: ModelField) -> bool: def is_scalar_field(field: ModelField) -> bool:
if (inner := root_model_inner_type(field.type_)) is not None:
return field_annotation_is_scalar(inner)
return is_pv1_scalar_field(field) return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool: def is_sequence_field(field: ModelField) -> bool:
if (inner := root_model_inner_type(field.type_)) is not None:
return field_annotation_is_sequence(inner)
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
def is_scalar_sequence_field(field: ModelField) -> bool: def is_scalar_sequence_field(field: ModelField) -> bool:
if (inner := root_model_inner_type(field.type_)) is not None:
return field_annotation_is_scalar_sequence(inner)
return is_pv1_scalar_sequence_field(field) return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool: def is_bytes_field(field: ModelField) -> bool:
@ -564,6 +571,9 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_sequence(inner)
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
for arg in get_args(annotation): for arg in get_args(annotation):
@ -595,14 +605,21 @@ def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool:
return root_model_inner_type(annotation) is not None return root_model_inner_type(annotation) is not None
def field_annotation_has_submodel_fields(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_has_submodel_fields(inner)
return lenient_issubclass(annotation, BaseModel)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_complex(inner)
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
if inner := root_model_inner_type(annotation):
return field_annotation_is_complex(inner)
return ( return (
_annotation_is_complex(annotation) _annotation_is_complex(annotation)
or _annotation_is_complex(origin) or _annotation_is_complex(origin)
@ -612,11 +629,17 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_scalar(annotation: Any) -> bool: def field_annotation_is_scalar(annotation: Any) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_scalar(inner)
# handle Ellipsis here to make tuple[int, ...] work nicely # handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation) return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_scalar_sequence(inner)
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False at_least_one_scalar_sequence = False

8
fastapi/dependencies/utils.py

@ -30,6 +30,7 @@ from fastapi._compat import (
copy_field_info, copy_field_info,
create_body_model, create_body_model,
evaluate_forwardref, evaluate_forwardref,
field_annotation_has_submodel_fields,
field_annotation_is_root_model, field_annotation_is_root_model,
field_annotation_is_scalar, field_annotation_is_scalar,
get_annotation_from_field_info, get_annotation_from_field_info,
@ -214,7 +215,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
if not fields: if not fields:
return fields return fields
first_field = fields[0] first_field = fields[0]
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): if len(fields) == 1 and field_annotation_has_submodel_fields(first_field.type_):
fields_to_extract = get_cached_model_fields(first_field.type_) fields_to_extract = get_cached_model_fields(first_field.type_)
return fields_to_extract return fields_to_extract
return fields return fields
@ -757,6 +758,7 @@ def request_params_to_args(
single_not_embedded_field = False single_not_embedded_field = False
default_convert_underscores = True default_convert_underscores = True
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
if field_annotation_has_submodel_fields(first_field.type_):
fields_to_extract = get_cached_model_fields(first_field.type_) fields_to_extract = get_cached_model_fields(first_field.type_)
single_not_embedded_field = True single_not_embedded_field = True
# If headers are in a Pydantic model, the way to disable convert_underscores # If headers are in a Pydantic model, the way to disable convert_underscores
@ -921,7 +923,9 @@ async def request_body_to_args(
fields_to_extract: List[ModelField] = body_fields fields_to_extract: List[ModelField] = body_fields
if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel): if single_not_embedded_field and field_annotation_has_submodel_fields(
first_field.type_
):
fields_to_extract = get_cached_model_fields(first_field.type_) fields_to_extract = get_cached_model_fields(first_field.type_)
if isinstance(received_body, FormData): if isinstance(received_body, FormData):

163
tests/test_root_model.py

@ -1,18 +1,42 @@
from typing import Any, Dict, List, Type, Union from datetime import date
from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi import Body, FastAPI, Path, Query from fastapi import Body, FastAPI, Header, Path, Query
from fastapi._compat import PYDANTIC_V2 from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel from pydantic import BaseModel, Field
from typing_extensions import Annotated
from .utils import needs_pydanticv2
app = FastAPI() app = FastAPI()
if PYDANTIC_V2: if PYDANTIC_V2:
from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer from pydantic import RootModel
else:
from pydantic import BaseModel
T = TypeVar("T")
class RootModel(Generic[T]):
def __class_getitem__(cls, type_):
class Inner(BaseModel):
__root__: type_
return Inner
Basic = RootModel[int]
class DictWrap(RootModel[Dict[str, int]]):
pass
Basic = RootModel[int] if PYDANTIC_V2:
from pydantic import ConfigDict, Field, field_validator, model_serializer
class FieldWrap(RootModel[str]): class FieldWrap(RootModel[str]):
model_config = ConfigDict( model_config = ConfigDict(
@ -42,14 +66,8 @@ if PYDANTIC_V2:
def parse(self): def parse(self):
return self.root[len("foo_") :] # removeprefix return self.root[len("foo_") :] # removeprefix
class DictWrap(RootModel[Dict[str, int]]):
pass
else: else:
from pydantic import Field, validator from pydantic import validator
class Basic(BaseModel):
__root__: int
class FieldWrap(BaseModel): class FieldWrap(BaseModel):
class Config: class Config:
@ -80,52 +98,49 @@ else:
def parse(self): def parse(self):
return self.__root__[len("foo_") :] # removeprefix return self.__root__[len("foo_") :] # removeprefix
class DictWrap(BaseModel):
__root__: Dict[str, int]
@app.get("/query/basic") @app.get("/query/basic")
def query_basic(q: Basic = Query()): def query_basic(q: Annotated[Basic, Query()]):
return {"q": q} return {"q": q}
@app.get("/query/fieldwrap") @app.get("/query/fieldwrap")
def query_fieldwrap(q: FieldWrap = Query()): def query_fieldwrap(q: Annotated[FieldWrap, Query()]):
return {"q": q} return {"q": q}
@app.get("/query/customparsed") @app.get("/query/customparsed")
def query_customparsed(q: CustomParsed = Query()): def query_customparsed(q: Annotated[CustomParsed, Query()]):
return {"q": q.parse()} return {"q": q.parse()}
@app.get("/path/basic/{p}") @app.get("/path/basic/{p}")
def path_basic(p: Basic = Path()): def path_basic(p: Annotated[Basic, Path()]):
return {"p": p} return {"p": p}
@app.get("/path/fieldwrap/{p}") @app.get("/path/fieldwrap/{p}")
def path_fieldwrap(p: FieldWrap = Path()): def path_fieldwrap(p: Annotated[FieldWrap, Path()]):
return {"p": p} return {"p": p}
@app.get("/path/customparsed/{p}") @app.get("/path/customparsed/{p}")
def path_customparsed(p: CustomParsed = Path()): def path_customparsed(p: Annotated[CustomParsed, Path()]):
return {"p": p.parse()} return {"p": p.parse()}
@app.post("/body/basic") @app.post("/body/basic")
def body_basic(b: Basic = Body()): def body_basic(b: Annotated[Basic, Body()]):
return {"b": b} return {"b": b}
@app.post("/body/fieldwrap") @app.post("/body/fieldwrap")
def body_fieldwrap(b: FieldWrap = Body()): def body_fieldwrap(b: Annotated[FieldWrap, Body()]):
return {"b": b} return {"b": b}
@app.post("/body/customparsed") @app.post("/body/customparsed")
def body_customparsed(b: CustomParsed = Body()): def body_customparsed(b: Annotated[CustomParsed, Body()]):
return {"b": b.parse()} return {"b": b.parse()}
@ -145,17 +160,17 @@ def body_default_customparsed(b: CustomParsed):
@app.get("/echo/basic") @app.get("/echo/basic")
def echo_basic(q: Basic = Query()) -> Basic: def echo_basic(q: Annotated[Basic, Query()]) -> Basic:
return q return q
@app.get("/echo/fieldwrap") @app.get("/echo/fieldwrap")
def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap: def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap:
return q return q
@app.get("/echo/customparsed") @app.get("/echo/customparsed")
def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed: def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed:
return q return q
@ -194,43 +209,91 @@ def test_root_model_200(url: str, response_json: Any, request_body: Any):
assert response.json() == response_json assert response.json() == response_json
def test_root_model_union(): @pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]])
if PYDANTIC_V2: @pytest.mark.parametrize("outer", [None, List, Sequence])
from pydantic import RootModel def test2_root_model_200__basic(type_: Type, outer: Optional[Type]):
inner = outer[type_] if outer else type_
Model = RootModel[inner]
app2 = FastAPI()
RootModelInt = RootModel[int] if outer is None:
RootModelStr = RootModel[str]
@app2.get("/path/{p}")
def path_basic2(p: Annotated[Model, Path()]) -> Model:
return p
else: else:
with pytest.raises(
AssertionError, match="Path params must be of one of the supported types"
):
class RootModelInt(BaseModel): @app2.get("/path/{p}")
__root__: int def path_basic2(p: Annotated[Model, Path()]) -> Model:
return p # pragma: nocover
@app2.get("/query")
def query_basic2(q: Annotated[Model, Query()]) -> Model:
return q
@app2.get("/header")
def path_basic2(h: Annotated[Model, Header()]) -> Model:
return h
@app2.post("/body")
def body_basic2(b: Annotated[Model, Body()]) -> Model:
return b
client2 = TestClient(app2)
if outer is None:
expected = "42" if type_ in (str, bytes) else 42
else:
expected = ["42", "43"] if type_ in (str, bytes) else [42, 43]
if outer is None:
assert client2.get("/path/42").json() == expected
assert client2.get("/query?q=42").json() == expected
assert client2.get("/header", headers={"h": "42"}).json() == expected
else:
assert client2.get("/path/42").json()["detail"] == "Not Found"
assert client2.get("/query?q=42&q=43").json() == expected
assert (
client2.get("/header", headers=[("h", "42"), ("h", "43")]).json()
== expected
)
assert client2.post("/body", json=expected).json() == expected
class RootModelStr(BaseModel):
__root__: str @pytest.mark.parametrize("left", [RootModel[date], date])
@pytest.mark.parametrize("right", [RootModel[int], int])
@needs_pydanticv2
def test_root_model_union(left: Any, right: Any):
Model = Union[left, right]
app2 = FastAPI() app2 = FastAPI()
@app2.post("/union") @app2.get("/query")
def union_handler(b: Union[RootModelInt, RootModelStr]): def handler1(q: Annotated[Model, Query()]):
return {"q": q}
@app2.post("/body")
def handler2(b: Annotated[Model, Body()]):
return {"b": b} return {"b": b}
client2 = TestClient(app2) client2 = TestClient(app2)
for body in [42, "foo"]: for val in ["2025-01-02", 42]:
response = client2.post("/union", json=body) response = client2.get(f"/query?q={val}")
assert response.status_code == 200, response.text
assert response.json() == {"q": val}
response = client2.post("/body", json=val)
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"b": body} assert response.json() == {"b": val}
response = client2.post("/union", json=["bad_list"]) response = client2.get("/query?q=bad")
assert response.status_code == 422, response.text assert response.status_code == 422, response.text
if PYDANTIC_V2:
assert {detail["msg"] for detail in response.json()["detail"]} == {
"Input should be a valid integer",
"Input should be a valid string",
}
else:
assert {detail["msg"] for detail in response.json()["detail"]} == { assert {detail["msg"] for detail in response.json()["detail"]} == {
"value is not a valid integer", "Input should be a valid date or datetime, input is too short",
"str type expected", "Input should be a valid integer, unable to parse string as an integer",
} }

Loading…
Cancel
Save