diff --git a/fastapi/_compat.py b/fastapi/_compat.py index bf9e46e68..9fc2cc3dd 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -239,7 +239,11 @@ if PYDANTIC_V2: return field_mapping, definitions # type: ignore[return-value] def is_scalar_field(field: ModelField) -> bool: - return field_annotation_is_scalar(field.field_info.annotation) + from fastapi import params + + return field_annotation_is_scalar( + field.field_info.annotation + ) and not isinstance(field.field_info, params.Body) def is_sequence_field(field: ModelField) -> bool: return field_annotation_is_sequence(field.field_info.annotation) @@ -402,12 +406,6 @@ else: def is_pv1_scalar_field(field: ModelField) -> bool: from fastapi import params - if ( - lenient_issubclass(field.type_, BaseModel) - and "__root__" in field.type_.__fields__ - ): - return is_pv1_scalar_field(field.type_.__fields__["__root__"]) - field_info = field.field_info if not ( field.shape == SHAPE_SINGLETON # type: ignore[attr-defined] @@ -501,12 +499,21 @@ else: ) def is_scalar_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_scalar(inner) + return is_pv1_scalar_field(field) def is_sequence_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_sequence(inner) + return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] def is_scalar_sequence_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_scalar_sequence(inner) + return is_pv1_scalar_sequence_field(field) def is_bytes_field(field: ModelField) -> bool: @@ -564,6 +571,9 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_sequence(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: for arg in get_args(annotation): @@ -595,14 +605,21 @@ def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool: return root_model_inner_type(annotation) is not None +def field_annotation_has_submodel_fields(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_has_submodel_fields(inner) + + return lenient_issubclass(annotation, BaseModel) + + def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_complex(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) - if inner := root_model_inner_type(annotation): - return field_annotation_is_complex(inner) - return ( _annotation_is_complex(annotation) or _annotation_is_complex(origin) @@ -612,11 +629,17 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_scalar(annotation: Any) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_scalar(inner) + # handle Ellipsis here to make tuple[int, ...] work nicely return annotation is Ellipsis or not field_annotation_is_complex(annotation) def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_scalar_sequence(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: at_least_one_scalar_sequence = False diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 359004601..8cef138fd 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -30,6 +30,7 @@ from fastapi._compat import ( copy_field_info, create_body_model, evaluate_forwardref, + field_annotation_has_submodel_fields, field_annotation_is_root_model, field_annotation_is_scalar, get_annotation_from_field_info, @@ -214,7 +215,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: if not fields: return fields first_field = fields[0] - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): + if len(fields) == 1 and field_annotation_has_submodel_fields(first_field.type_): fields_to_extract = get_cached_model_fields(first_field.type_) return fields_to_extract return fields @@ -757,8 +758,9 @@ def request_params_to_args( single_not_embedded_field = False default_convert_underscores = True if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - single_not_embedded_field = True + if field_annotation_has_submodel_fields(first_field.type_): + fields_to_extract = get_cached_model_fields(first_field.type_) + single_not_embedded_field = True # If headers are in a Pydantic model, the way to disable convert_underscores # would be with Header(convert_underscores=False) at the Pydantic model level default_convert_underscores = getattr( @@ -921,7 +923,9 @@ async def request_body_to_args( fields_to_extract: List[ModelField] = body_fields - if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel): + if single_not_embedded_field and field_annotation_has_submodel_fields( + first_field.type_ + ): fields_to_extract = get_cached_model_fields(first_field.type_) if isinstance(received_body, FormData): diff --git a/tests/test_root_model.py b/tests/test_root_model.py index 5c91bc10a..78a568e10 100644 --- a/tests/test_root_model.py +++ b/tests/test_root_model.py @@ -1,18 +1,42 @@ -from typing import Any, Dict, List, Type, Union +from datetime import date +from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union import pytest from dirty_equals import IsDict -from fastapi import Body, FastAPI, Path, Query +from fastapi import Body, FastAPI, Header, Path, Query from fastapi._compat import PYDANTIC_V2 from fastapi.testclient import TestClient -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from .utils import needs_pydanticv2 app = FastAPI() if PYDANTIC_V2: - from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer + from pydantic import RootModel +else: + from pydantic import BaseModel + + T = TypeVar("T") + + class RootModel(Generic[T]): + def __class_getitem__(cls, type_): + class Inner(BaseModel): + __root__: type_ + + return Inner + + +Basic = RootModel[int] + + +class DictWrap(RootModel[Dict[str, int]]): + pass - Basic = RootModel[int] + +if PYDANTIC_V2: + from pydantic import ConfigDict, Field, field_validator, model_serializer class FieldWrap(RootModel[str]): model_config = ConfigDict( @@ -42,14 +66,8 @@ if PYDANTIC_V2: def parse(self): return self.root[len("foo_") :] # removeprefix - - class DictWrap(RootModel[Dict[str, int]]): - pass else: - from pydantic import Field, validator - - class Basic(BaseModel): - __root__: int + from pydantic import validator class FieldWrap(BaseModel): class Config: @@ -80,52 +98,49 @@ else: def parse(self): return self.__root__[len("foo_") :] # removeprefix - class DictWrap(BaseModel): - __root__: Dict[str, int] - @app.get("/query/basic") -def query_basic(q: Basic = Query()): +def query_basic(q: Annotated[Basic, Query()]): return {"q": q} @app.get("/query/fieldwrap") -def query_fieldwrap(q: FieldWrap = Query()): +def query_fieldwrap(q: Annotated[FieldWrap, Query()]): return {"q": q} @app.get("/query/customparsed") -def query_customparsed(q: CustomParsed = Query()): +def query_customparsed(q: Annotated[CustomParsed, Query()]): return {"q": q.parse()} @app.get("/path/basic/{p}") -def path_basic(p: Basic = Path()): +def path_basic(p: Annotated[Basic, Path()]): return {"p": p} @app.get("/path/fieldwrap/{p}") -def path_fieldwrap(p: FieldWrap = Path()): +def path_fieldwrap(p: Annotated[FieldWrap, Path()]): return {"p": p} @app.get("/path/customparsed/{p}") -def path_customparsed(p: CustomParsed = Path()): +def path_customparsed(p: Annotated[CustomParsed, Path()]): return {"p": p.parse()} @app.post("/body/basic") -def body_basic(b: Basic = Body()): +def body_basic(b: Annotated[Basic, Body()]): return {"b": b} @app.post("/body/fieldwrap") -def body_fieldwrap(b: FieldWrap = Body()): +def body_fieldwrap(b: Annotated[FieldWrap, Body()]): return {"b": b} @app.post("/body/customparsed") -def body_customparsed(b: CustomParsed = Body()): +def body_customparsed(b: Annotated[CustomParsed, Body()]): return {"b": b.parse()} @@ -145,17 +160,17 @@ def body_default_customparsed(b: CustomParsed): @app.get("/echo/basic") -def echo_basic(q: Basic = Query()) -> Basic: +def echo_basic(q: Annotated[Basic, Query()]) -> Basic: return q @app.get("/echo/fieldwrap") -def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap: +def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap: return q @app.get("/echo/customparsed") -def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed: +def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed: return q @@ -194,44 +209,92 @@ def test_root_model_200(url: str, response_json: Any, request_body: Any): assert response.json() == response_json -def test_root_model_union(): - if PYDANTIC_V2: - from pydantic import RootModel +@pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]]) +@pytest.mark.parametrize("outer", [None, List, Sequence]) +def test2_root_model_200__basic(type_: Type, outer: Optional[Type]): + inner = outer[type_] if outer else type_ + Model = RootModel[inner] + + app2 = FastAPI() + + if outer is None: + + @app2.get("/path/{p}") + def path_basic2(p: Annotated[Model, Path()]) -> Model: + return p + else: + with pytest.raises( + AssertionError, match="Path params must be of one of the supported types" + ): + + @app2.get("/path/{p}") + def path_basic2(p: Annotated[Model, Path()]) -> Model: + return p # pragma: nocover + + @app2.get("/query") + def query_basic2(q: Annotated[Model, Query()]) -> Model: + return q + + @app2.get("/header") + def path_basic2(h: Annotated[Model, Header()]) -> Model: + return h - RootModelInt = RootModel[int] - RootModelStr = RootModel[str] + @app2.post("/body") + def body_basic2(b: Annotated[Model, Body()]) -> Model: + return b + + client2 = TestClient(app2) + + if outer is None: + expected = "42" if type_ in (str, bytes) else 42 else: + expected = ["42", "43"] if type_ in (str, bytes) else [42, 43] + + if outer is None: + assert client2.get("/path/42").json() == expected + assert client2.get("/query?q=42").json() == expected + assert client2.get("/header", headers={"h": "42"}).json() == expected + else: + assert client2.get("/path/42").json()["detail"] == "Not Found" + assert client2.get("/query?q=42&q=43").json() == expected + assert ( + client2.get("/header", headers=[("h", "42"), ("h", "43")]).json() + == expected + ) + assert client2.post("/body", json=expected).json() == expected - class RootModelInt(BaseModel): - __root__: int - class RootModelStr(BaseModel): - __root__: str +@pytest.mark.parametrize("left", [RootModel[date], date]) +@pytest.mark.parametrize("right", [RootModel[int], int]) +@needs_pydanticv2 +def test_root_model_union(left: Any, right: Any): + Model = Union[left, right] app2 = FastAPI() - @app2.post("/union") - def union_handler(b: Union[RootModelInt, RootModelStr]): + @app2.get("/query") + def handler1(q: Annotated[Model, Query()]): + return {"q": q} + + @app2.post("/body") + def handler2(b: Annotated[Model, Body()]): return {"b": b} client2 = TestClient(app2) - for body in [42, "foo"]: - response = client2.post("/union", json=body) + for val in ["2025-01-02", 42]: + response = client2.get(f"/query?q={val}") + assert response.status_code == 200, response.text + assert response.json() == {"q": val} + response = client2.post("/body", json=val) assert response.status_code == 200, response.text - assert response.json() == {"b": body} + assert response.json() == {"b": val} - response = client2.post("/union", json=["bad_list"]) + response = client2.get("/query?q=bad") assert response.status_code == 422, response.text - if PYDANTIC_V2: - assert {detail["msg"] for detail in response.json()["detail"]} == { - "Input should be a valid integer", - "Input should be a valid string", - } - else: - assert {detail["msg"] for detail in response.json()["detail"]} == { - "value is not a valid integer", - "str type expected", - } + assert {detail["msg"] for detail in response.json()["detail"]} == { + "Input should be a valid date or datetime, input is too short", + "Input should be a valid integer, unable to parse string as an integer", + } @pytest.mark.parametrize(