diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 227ad837d..9fc2cc3dd 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -50,7 +50,7 @@ Url: Type[Any] if PYDANTIC_V2: from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError - from pydantic import TypeAdapter + from pydantic import RootModel, TypeAdapter from pydantic import ValidationError as ValidationError from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] GetJsonSchemaHandler as GetJsonSchemaHandler, @@ -292,6 +292,11 @@ if PYDANTIC_V2: for name, field_info in model.model_fields.items() ] + def root_model_inner_type(annotation: Any) -> Any: + if lenient_issubclass(annotation, RootModel): + return annotation.model_fields["root"].annotation + return None + else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 @@ -494,12 +499,21 @@ else: ) def is_scalar_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_scalar(inner) + return is_pv1_scalar_field(field) def is_sequence_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_sequence(inner) + return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] def is_scalar_sequence_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_scalar_sequence(inner) + return is_pv1_scalar_sequence_field(field) def is_bytes_field(field: ModelField) -> bool: @@ -530,6 +544,14 @@ else: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: return list(model.__fields__.values()) # type: ignore[attr-defined] + def root_model_inner_type(annotation: Any) -> Any: + if ( + lenient_issubclass(annotation, BaseModel) + and "__root__" in annotation.__fields__ + ): + return annotation.__fields__["__root__"].annotation + return None + def _regenerate_error_with_loc( *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] @@ -549,6 +571,9 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_sequence(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: for arg in get_args(annotation): @@ -572,7 +597,25 @@ def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: ) +def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + return any(field_annotation_is_root_model(arg) for arg in get_args(annotation)) + + return root_model_inner_type(annotation) is not None + + +def field_annotation_has_submodel_fields(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_has_submodel_fields(inner) + + return lenient_issubclass(annotation, BaseModel) + + def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_complex(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) @@ -586,11 +629,17 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_scalar(annotation: Any) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_scalar(inner) + # handle Ellipsis here to make tuple[int, ...] work nicely return annotation is Ellipsis or not field_annotation_is_complex(annotation) def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_scalar_sequence(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: at_least_one_scalar_sequence = False diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..8cef138fd 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -30,6 +30,8 @@ from fastapi._compat import ( copy_field_info, create_body_model, evaluate_forwardref, + field_annotation_has_submodel_fields, + field_annotation_is_root_model, field_annotation_is_scalar, get_annotation_from_field_info, get_cached_model_fields, @@ -213,7 +215,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: if not fields: return fields first_field = fields[0] - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): + if len(fields) == 1 and field_annotation_has_submodel_fields(first_field.type_): fields_to_extract = get_cached_model_fields(first_field.type_) return fields_to_extract return fields @@ -454,7 +456,11 @@ def analyze_param( type_annotation ) or is_uploadfile_sequence_annotation(type_annotation): field_info = params.File(annotation=use_annotation, default=default_value) - elif not field_annotation_is_scalar(annotation=type_annotation): + elif ( + not field_annotation_is_scalar(type_annotation) + # Root models by default regarded as bodies + or field_annotation_is_root_model(type_annotation) + ): field_info = params.Body(annotation=use_annotation, default=default_value) else: field_info = params.Query(annotation=use_annotation, default=default_value) @@ -752,8 +758,9 @@ def request_params_to_args( single_not_embedded_field = False default_convert_underscores = True if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - single_not_embedded_field = True + if field_annotation_has_submodel_fields(first_field.type_): + fields_to_extract = get_cached_model_fields(first_field.type_) + single_not_embedded_field = True # If headers are in a Pydantic model, the way to disable convert_underscores # would be with Header(convert_underscores=False) at the Pydantic model level default_convert_underscores = getattr( @@ -916,7 +923,9 @@ async def request_body_to_args( fields_to_extract: List[ModelField] = body_fields - if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel): + if single_not_embedded_field and field_annotation_has_submodel_fields( + first_field.type_ + ): fields_to_extract = get_cached_model_fields(first_field.type_) if isinstance(received_body, FormData): diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 451ea0760..bc761a0d5 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -220,7 +220,7 @@ def jsonable_encoder( encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: encoders.update(custom_encoder) - obj_dict = _model_dump( + serialized = _model_dump( obj, mode="json", include=include, @@ -230,10 +230,16 @@ def jsonable_encoder( exclude_none=exclude_none, exclude_defaults=exclude_defaults, ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] + + if ( + not PYDANTIC_V2 + and isinstance(serialized, dict) + and "__root__" in serialized + ): + serialized = serialized["__root__"] + return jsonable_encoder( - obj_dict, + serialized, exclude_none=exclude_none, exclude_defaults=exclude_defaults, # TODO: remove when deprecating Pydantic v1 diff --git a/tests/test_root_model.py b/tests/test_root_model.py new file mode 100644 index 000000000..78a568e10 --- /dev/null +++ b/tests/test_root_model.py @@ -0,0 +1,519 @@ +from datetime import date +from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union + +import pytest +from dirty_equals import IsDict +from fastapi import Body, FastAPI, Header, Path, Query +from fastapi._compat import PYDANTIC_V2 +from fastapi.testclient import TestClient +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from .utils import needs_pydanticv2 + +app = FastAPI() + +if PYDANTIC_V2: + from pydantic import RootModel +else: + from pydantic import BaseModel + + T = TypeVar("T") + + class RootModel(Generic[T]): + def __class_getitem__(cls, type_): + class Inner(BaseModel): + __root__: type_ + + return Inner + + +Basic = RootModel[int] + + +class DictWrap(RootModel[Dict[str, int]]): + pass + + +if PYDANTIC_V2: + from pydantic import ConfigDict, Field, field_validator, model_serializer + + class FieldWrap(RootModel[str]): + model_config = ConfigDict( + json_schema_extra={"description": "parameter starts with bar_"} + ) + root: str = Field(pattern=r"^bar_.*$") + + class CustomParsed(RootModel[str]): + model_config = ConfigDict( + json_schema_extra={"description": "parameter starts with foo_"} + ) + + @field_validator("root") + @classmethod + def validator(cls, v: str) -> str: + if not v.startswith("foo_"): + raise ValueError("must start with foo_") + return v + + @model_serializer + def serialize(self): + return ( + self.root + if self.root.endswith("_serialized") + else f"{self.root}_serialized" + ) + + def parse(self): + return self.root[len("foo_") :] # removeprefix +else: + from pydantic import validator + + class FieldWrap(BaseModel): + class Config: + schema_extra = {"description": "parameter starts with bar_"} + + __root__: str = Field(regex=r"^bar_.*$") + + class CustomParsed(BaseModel): + class Config: + schema_extra = {"description": "parameter starts with foo_"} + + __root__: str + + @validator("__root__") + @classmethod + def validator(cls, v: str) -> str: + if not v.startswith("foo_"): + raise ValueError("must start with foo_") + return v + + def dict(self, **kwargs: Any) -> Dict[str, Any]: + return { + "__root__": self.__root__ + if self.__root__.endswith("_serialized") + else f"{self.__root__}_serialized" + } + + def parse(self): + return self.__root__[len("foo_") :] # removeprefix + + +@app.get("/query/basic") +def query_basic(q: Annotated[Basic, Query()]): + return {"q": q} + + +@app.get("/query/fieldwrap") +def query_fieldwrap(q: Annotated[FieldWrap, Query()]): + return {"q": q} + + +@app.get("/query/customparsed") +def query_customparsed(q: Annotated[CustomParsed, Query()]): + return {"q": q.parse()} + + +@app.get("/path/basic/{p}") +def path_basic(p: Annotated[Basic, Path()]): + return {"p": p} + + +@app.get("/path/fieldwrap/{p}") +def path_fieldwrap(p: Annotated[FieldWrap, Path()]): + return {"p": p} + + +@app.get("/path/customparsed/{p}") +def path_customparsed(p: Annotated[CustomParsed, Path()]): + return {"p": p.parse()} + + +@app.post("/body/basic") +def body_basic(b: Annotated[Basic, Body()]): + return {"b": b} + + +@app.post("/body/fieldwrap") +def body_fieldwrap(b: Annotated[FieldWrap, Body()]): + return {"b": b} + + +@app.post("/body/customparsed") +def body_customparsed(b: Annotated[CustomParsed, Body()]): + return {"b": b.parse()} + + +@app.post("/body_default/basic") +def body_default_basic(b: Basic): + return {"b": b} + + +@app.post("/body_default/fieldwrap") +def body_default_fieldwrap(b: FieldWrap): + return {"b": b} + + +@app.post("/body_default/customparsed") +def body_default_customparsed(b: CustomParsed): + return {"b": b.parse()} + + +@app.get("/echo/basic") +def echo_basic(q: Annotated[Basic, Query()]) -> Basic: + return q + + +@app.get("/echo/fieldwrap") +def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap: + return q + + +@app.get("/echo/customparsed") +def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed: + return q + + +@app.post("/echo/dictwrap") +def echo_dictwrap(b: DictWrap) -> DictWrap: + return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"] + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "url, response_json, request_body", + [ + ("/query/basic?q=42", {"q": 42}, None), + ("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None), + ("/query/customparsed?q=foo_bar", {"q": "bar"}, None), + ("/path/basic/42", {"p": 42}, None), + ("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None), + ("/path/customparsed/foo_bar", {"p": "bar"}, None), + ("/body/basic", {"b": 42}, "42"), + ("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"), + ("/body/customparsed", {"b": "bar"}, "foo_bar"), + ("/body_default/basic", {"b": 42}, "42"), + ("/body_default/fieldwrap", {"b": "bar_baz"}, "bar_baz"), + ("/body_default/customparsed", {"b": "bar"}, "foo_bar"), + ("/echo/basic?q=42", 42, None), + ("/echo/fieldwrap?q=bar_baz", "bar_baz", None), + ("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None), + ("/echo/dictwrap", {"test": 1}, {"test": 1}), + ], +) +def test_root_model_200(url: str, response_json: Any, request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 200, response.text + assert response.json() == response_json + + +@pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]]) +@pytest.mark.parametrize("outer", [None, List, Sequence]) +def test2_root_model_200__basic(type_: Type, outer: Optional[Type]): + inner = outer[type_] if outer else type_ + Model = RootModel[inner] + + app2 = FastAPI() + + if outer is None: + + @app2.get("/path/{p}") + def path_basic2(p: Annotated[Model, Path()]) -> Model: + return p + else: + with pytest.raises( + AssertionError, match="Path params must be of one of the supported types" + ): + + @app2.get("/path/{p}") + def path_basic2(p: Annotated[Model, Path()]) -> Model: + return p # pragma: nocover + + @app2.get("/query") + def query_basic2(q: Annotated[Model, Query()]) -> Model: + return q + + @app2.get("/header") + def path_basic2(h: Annotated[Model, Header()]) -> Model: + return h + + @app2.post("/body") + def body_basic2(b: Annotated[Model, Body()]) -> Model: + return b + + client2 = TestClient(app2) + + if outer is None: + expected = "42" if type_ in (str, bytes) else 42 + else: + expected = ["42", "43"] if type_ in (str, bytes) else [42, 43] + + if outer is None: + assert client2.get("/path/42").json() == expected + assert client2.get("/query?q=42").json() == expected + assert client2.get("/header", headers={"h": "42"}).json() == expected + else: + assert client2.get("/path/42").json()["detail"] == "Not Found" + assert client2.get("/query?q=42&q=43").json() == expected + assert ( + client2.get("/header", headers=[("h", "42"), ("h", "43")]).json() + == expected + ) + assert client2.post("/body", json=expected).json() == expected + + +@pytest.mark.parametrize("left", [RootModel[date], date]) +@pytest.mark.parametrize("right", [RootModel[int], int]) +@needs_pydanticv2 +def test_root_model_union(left: Any, right: Any): + Model = Union[left, right] + + app2 = FastAPI() + + @app2.get("/query") + def handler1(q: Annotated[Model, Query()]): + return {"q": q} + + @app2.post("/body") + def handler2(b: Annotated[Model, Body()]): + return {"b": b} + + client2 = TestClient(app2) + for val in ["2025-01-02", 42]: + response = client2.get(f"/query?q={val}") + assert response.status_code == 200, response.text + assert response.json() == {"q": val} + response = client2.post("/body", json=val) + assert response.status_code == 200, response.text + assert response.json() == {"b": val} + + response = client2.get("/query?q=bad") + assert response.status_code == 422, response.text + assert {detail["msg"] for detail in response.json()["detail"]} == { + "Input should be a valid date or datetime, input is too short", + "Input should be a valid integer, unable to parse string as an integer", + } + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/basic?q=my_bad_not_int", ["query", "q"], None), + ("/path/basic/my_bad_not_int", ["path", "p"], None), + ("/body/basic", ["body"], "my_bad_not_int"), + ("/body_default/basic", ["body"], "my_bad_not_int"), + ], +) +def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "int_parsing", + "loc": error_path, + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "my_bad_not_int", + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + ) + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None), + ("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None), + ("/body/fieldwrap", ["body"], "my_bad_prefix_val"), + ("/body_default/fieldwrap", ["body"], "my_bad_prefix_val"), + ], +) +def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "string_pattern_mismatch", + "loc": error_path, + "msg": "String should match pattern '^bar_.*$'", + "input": "my_bad_prefix_val", + "ctx": {"pattern": "^bar_.*$"}, + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": 'string does not match regex "^bar_.*$"', + "type": "value_error.str.regex", + "ctx": {"pattern": "^bar_.*$"}, + } + ] + } + ) + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None), + ("/path/customparsed/my_bad_prefix_val", ["path", "p"], None), + ("/body/customparsed", ["body"], "my_bad_prefix_val"), + ("/body_default/customparsed", ["body"], "my_bad_prefix_val"), + ], +) +def test_root_model_customparsed_422( + url: str, error_path: List[str], request_body: Any +): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "value_error", + "loc": error_path, + "msg": "Value error, must start with foo_", + "input": "my_bad_prefix_val", + "ctx": {"error": {}}, + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": "must start with foo_", + "type": "value_error", + } + ] + } + ) + + +def test_root_model_dictwrap_422(): + response = client.post("/echo/dictwrap", json={"test": "fail_not_int"}) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "int_parsing", + "loc": ["body", "test"], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "fail_not_int", + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": ["body", "__root__", "test"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + ) + + +@pytest.mark.parametrize( + "model, path_name, expected_model_schema", + [ + (Basic, "basic", {"type": "integer"}), + ( + FieldWrap, + "fieldwrap", + { + "type": "string", + "pattern": "^bar_.*$", + "description": "parameter starts with bar_", + }, + ), + ( + CustomParsed, + "customparsed", + { + "type": "string", + "description": "parameter starts with foo_", + }, + ), + ], +) +def test_openapi_schema( + model: Type, path_name: str, expected_model_schema: Dict[str, Any] +): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + paths = response.json()["paths"] + ref_name = model.__name__.replace("[", "_").replace("]", "_") + schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}} + + assert paths[f"/query/{path_name}"]["get"]["parameters"] == [ + {"in": "query", "name": "q", "required": True, **schema_ref} + ] + assert paths[f"/path/{path_name}/{{p}}"]["get"]["parameters"] == [ + {"in": "path", "name": "p", "required": True, **schema_ref} + ] + assert paths[f"/body/{path_name}"]["post"]["requestBody"] == { + "content": {"application/json": schema_ref}, + "required": True, + } + assert paths[f"/body_default/{path_name}"]["post"]["requestBody"] == { + "content": {"application/json": schema_ref}, + "required": True, + } + assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == { + "content": {"application/json": schema_ref}, + "description": "Successful Response", + } + + model_schema = response.json()["components"]["schemas"][ref_name] + model_schema.pop("title") + assert model_schema == expected_model_schema + + +def test_openapi_schema_dictwrap(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + operation = response.json()["paths"]["/echo/dictwrap"]["post"] + ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}} + assert operation["requestBody"] == { + "content": {"application/json": ref}, + "required": True, + } + assert operation["responses"]["200"] == { + "content": {"application/json": ref}, + "description": "Successful Response", + } + assert response.json()["components"]["schemas"]["DictWrap"] == { + "title": "DictWrap", + "type": "object", + "additionalProperties": {"type": "integer"}, + }