From 867d75b5c8d8f8dc7e8ebcbf20bd658a71e41331 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Sun, 17 Mar 2024 15:58:31 +0200 Subject: [PATCH 1/4] Support Pydantic root model as query parameter --- fastapi/_compat.py | 30 ++- fastapi/encoders.py | 10 +- tests/test_root_model.py | 393 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 427 insertions(+), 6 deletions(-) create mode 100644 tests/test_root_model.py diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 227ad837d..8f25244e9 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -50,7 +50,7 @@ Url: Type[Any] if PYDANTIC_V2: from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError - from pydantic import TypeAdapter + from pydantic import RootModel, TypeAdapter from pydantic import ValidationError as ValidationError from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] GetJsonSchemaHandler as GetJsonSchemaHandler, @@ -292,6 +292,11 @@ if PYDANTIC_V2: for name, field_info in model.model_fields.items() ] + def root_model_inner_type(annotation: Any) -> Any: + if lenient_issubclass(annotation, RootModel): + return annotation.model_fields["root"].annotation + return None + else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 @@ -401,6 +406,12 @@ else: def is_pv1_scalar_field(field: ModelField) -> bool: from fastapi import params + if ( + lenient_issubclass(field.type_, BaseModel) + and "__root__" in field.type_.__fields__ + ): + return is_pv1_scalar_field(field.type_.__fields__["__root__"]) + field_info = field.field_info if not ( field.shape == SHAPE_SINGLETON # type: ignore[attr-defined] @@ -494,7 +505,11 @@ else: ) def is_scalar_field(field: ModelField) -> bool: - return is_pv1_scalar_field(field) + from fastapi import params + + return is_pv1_scalar_field(field) and not isinstance( + field.field_info, params.Body + ) def is_sequence_field(field: ModelField) -> bool: return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] @@ -530,6 +545,14 @@ else: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: return list(model.__fields__.values()) # type: ignore[attr-defined] + def root_model_inner_type(annotation: Any) -> Any: + if ( + lenient_issubclass(annotation, BaseModel) + and "__root__" in annotation.__fields__ + ): + return annotation.__fields__["__root__"].annotation + return None + def _regenerate_error_with_loc( *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] @@ -577,6 +600,9 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: if origin is Union or origin is UnionType: return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) + if inner := root_model_inner_type(annotation): + return field_annotation_is_complex(inner) + return ( _annotation_is_complex(annotation) or _annotation_is_complex(origin) diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 451ea0760..aa5e04fc6 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -220,7 +220,7 @@ def jsonable_encoder( encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: encoders.update(custom_encoder) - obj_dict = _model_dump( + serialized = _model_dump( obj, mode="json", include=include, @@ -230,10 +230,12 @@ def jsonable_encoder( exclude_none=exclude_none, exclude_defaults=exclude_defaults, ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] + if ( + isinstance(serialized, dict) and "__root__" in serialized + ): # TODO: remove when deprecating Pydantic v1 + serialized = serialized["__root__"] return jsonable_encoder( - obj_dict, + serialized, exclude_none=exclude_none, exclude_defaults=exclude_defaults, # TODO: remove when deprecating Pydantic v1 diff --git a/tests/test_root_model.py b/tests/test_root_model.py new file mode 100644 index 000000000..b3d2acc6d --- /dev/null +++ b/tests/test_root_model.py @@ -0,0 +1,393 @@ +from typing import Any, Dict, List, Type + +import pytest +from dirty_equals import IsDict +from fastapi import Body, FastAPI, Path, Query +from fastapi._compat import PYDANTIC_V2 +from fastapi.testclient import TestClient +from fastapi.utils import match_pydantic_error_url +from pydantic import BaseModel + +app = FastAPI() + +if PYDANTIC_V2: + from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer + + class Basic(RootModel[int]): + pass + + class FieldWrap(RootModel[str]): + model_config = ConfigDict( + json_schema_extra={"description": "parameter starts with bar_"} + ) + root: str = Field(pattern=r"^bar_.*$") + + class CustomParsed(RootModel[str]): + model_config = ConfigDict( + json_schema_extra={"description": "parameter starts with foo_"} + ) + + @field_validator("root") + @classmethod + def validator(cls, v: str) -> str: + if not v.startswith("foo_"): + raise ValueError("must start with foo_") + return v + + @model_serializer + def serialize(self): + return ( + self.root + if self.root.endswith("_serialized") + else f"{self.root}_serialized" + ) + + def parse(self): + return self.root[len("foo_") :] # removeprefix + + class DictWrap(RootModel[Dict[str, int]]): + pass +else: + from pydantic import Field, validator + + class Basic(BaseModel): + __root__: int + + class FieldWrap(BaseModel): + class Config: + schema_extra = {"description": "parameter starts with bar_"} + + __root__: str = Field(regex=r"^bar_.*$") + + class CustomParsed(BaseModel): + class Config: + schema_extra = {"description": "parameter starts with foo_"} + + __root__: str + + @validator("__root__") + @classmethod + def validator(cls, v: str) -> str: + if not v.startswith("foo_"): + raise ValueError("must start with foo_") + return v + + def dict(self, **kwargs: Any) -> Dict[str, Any]: + return { + "__root__": self.__root__ + if self.__root__.endswith("_serialized") + else f"{self.__root__}_serialized" + } + + def parse(self): + return self.__root__[len("foo_") :] # removeprefix + + class DictWrap(BaseModel): + __root__: Dict[str, int] + + +@app.get("/query/basic") +def query_basic(q: Basic = Query()): + return {"q": q} + + +@app.get("/query/fieldwrap") +def query_fieldwrap(q: FieldWrap = Query()): + return {"q": q} + + +@app.get("/query/customparsed") +def query_customparsed(q: CustomParsed = Query()): + return {"q": q.parse()} + + +@app.get("/path/basic/{p}") +def path_basic(p: Basic = Path()): + return {"p": p} + + +@app.get("/path/fieldwrap/{p}") +def path_fieldwrap(p: FieldWrap = Path()): + return {"p": p} + + +@app.get("/path/customparsed/{p}") +def path_customparsed(p: CustomParsed = Path()): + return {"p": p.parse()} + + +@app.post("/body/basic") +def body_basic(b: Basic = Body()): + return {"b": b} + + +@app.post("/body/fieldwrap") +def body_fieldwrap(b: FieldWrap = Body()): + return {"b": b} + + +@app.post("/body/customparsed") +def body_customparsed(b: CustomParsed = Body()): + return {"b": b.parse()} + + +@app.get("/echo/basic") +def echo_basic(q: Basic = Query()) -> Basic: + return q + + +@app.get("/echo/fieldwrap") +def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap: + return q + + +@app.get("/echo/customparsed") +def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed: + return q + + +@app.post("/echo/dictwrap") +def echo_dictwrap(b: DictWrap) -> DictWrap: + return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"] + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "url, response_json, request_body", + [ + ("/query/basic?q=42", {"q": 42}, None), + ("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None), + ("/query/customparsed?q=foo_bar", {"q": "bar"}, None), + ("/path/basic/42", {"p": 42}, None), + ("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None), + ("/path/customparsed/foo_bar", {"p": "bar"}, None), + ("/body/basic", {"b": 42}, "42"), + ("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"), + ("/body/customparsed", {"b": "bar"}, "foo_bar"), + ("/echo/basic?q=42", 42, None), + ("/echo/fieldwrap?q=bar_baz", "bar_baz", None), + ("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None), + ("/echo/dictwrap", {"test": 1}, {"test": 1}), + ], +) +def test_root_model_200(url: str, response_json: Any, request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 200, response.text + assert response.json() == response_json + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/basic?q=my_bad_not_int", ["query", "q"], None), + ("/path/basic/my_bad_not_int", ["path", "p"], None), + ("/body/basic", ["body"], "my_bad_not_int"), + ], +) +def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "int_parsing", + "loc": error_path, + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "my_bad_not_int", + "url": match_pydantic_error_url("int_parsing"), + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + ) + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None), + ("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None), + ("/body/fieldwrap", ["body"], "my_bad_prefix_val"), + ], +) +def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "string_pattern_mismatch", + "loc": error_path, + "msg": "String should match pattern '^bar_.*$'", + "input": "my_bad_prefix_val", + "url": match_pydantic_error_url("string_pattern_mismatch"), + "ctx": {"pattern": "^bar_.*$"}, + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": 'string does not match regex "^bar_.*$"', + "type": "value_error.str.regex", + "ctx": {"pattern": "^bar_.*$"}, + } + ] + } + ) + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None), + ("/path/customparsed/my_bad_prefix_val", ["path", "p"], None), + ("/body/customparsed", ["body"], "my_bad_prefix_val"), + ], +) +def test_root_model_customparsed_422( + url: str, error_path: List[str], request_body: Any +): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "value_error", + "loc": error_path, + "msg": "Value error, must start with foo_", + "input": "my_bad_prefix_val", + "ctx": {"error": {}}, + "url": match_pydantic_error_url("value_error"), + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": "must start with foo_", + "type": "value_error", + } + ] + } + ) + + +def test_root_model_dictwrap_422(): + response = client.post("/echo/dictwrap", json={"test": "fail_not_int"}) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "int_parsing", + "loc": ["body", "test"], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "fail_not_int", + "url": match_pydantic_error_url("int_parsing"), + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": ["body", "__root__", "test"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + ) + + +@pytest.mark.parametrize( + "model, model_schema", + [ + (Basic, {"title": "Basic", "type": "integer"}), + ( + FieldWrap, + { + "title": "FieldWrap", + "type": "string", + "pattern": "^bar_.*$", + "description": "parameter starts with bar_", + }, + ), + ( + CustomParsed, + { + "title": "CustomParsed", + "type": "string", + "description": "parameter starts with foo_", + }, + ), + ], +) +def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + paths = response.json()["paths"] + ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}} + assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [ + {"in": "query", "name": "q", "required": True, **ref} + ] + assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [ + {"in": "path", "name": "p", "required": True, **ref} + ] + assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == { + "content": {"application/json": ref}, + "required": True, + } + assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { + "content": {"application/json": ref}, + "description": "Successful Response", + } + assert response.json()["components"]["schemas"][model.__name__] == { + "title": model.__name__, + **model_schema, + } + + +def test_openapi_schema_dictwrap(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + operation = response.json()["paths"]["/echo/dictwrap"]["post"] + ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}} + assert operation["requestBody"] == { + "content": {"application/json": ref}, + "required": True, + } + assert operation["responses"]["200"] == { + "content": {"application/json": ref}, + "description": "Successful Response", + } + assert response.json()["components"]["schemas"]["DictWrap"] == { + "title": "DictWrap", + "type": "object", + "additionalProperties": {"type": "integer"}, + } From 320084c60976784481150584e2a97a0b9212884f Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Mon, 18 Mar 2024 09:03:11 +0200 Subject: [PATCH 2/4] Root model by default again as body --- fastapi/_compat.py | 20 +++++------ fastapi/dependencies/utils.py | 7 +++- tests/test_root_model.py | 67 ++++++++++++++++++++++++++++++++++- 3 files changed, 82 insertions(+), 12 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 8f25244e9..bf9e46e68 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -239,11 +239,7 @@ if PYDANTIC_V2: return field_mapping, definitions # type: ignore[return-value] def is_scalar_field(field: ModelField) -> bool: - from fastapi import params - - return field_annotation_is_scalar( - field.field_info.annotation - ) and not isinstance(field.field_info, params.Body) + return field_annotation_is_scalar(field.field_info.annotation) def is_sequence_field(field: ModelField) -> bool: return field_annotation_is_sequence(field.field_info.annotation) @@ -505,11 +501,7 @@ else: ) def is_scalar_field(field: ModelField) -> bool: - from fastapi import params - - return is_pv1_scalar_field(field) and not isinstance( - field.field_info, params.Body - ) + return is_pv1_scalar_field(field) def is_sequence_field(field: ModelField) -> bool: return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] @@ -595,6 +587,14 @@ def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: ) +def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + return any(field_annotation_is_root_model(arg) for arg in get_args(annotation)) + + return root_model_inner_type(annotation) is not None + + def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: origin = get_origin(annotation) if origin is Union or origin is UnionType: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..359004601 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -30,6 +30,7 @@ from fastapi._compat import ( copy_field_info, create_body_model, evaluate_forwardref, + field_annotation_is_root_model, field_annotation_is_scalar, get_annotation_from_field_info, get_cached_model_fields, @@ -454,7 +455,11 @@ def analyze_param( type_annotation ) or is_uploadfile_sequence_annotation(type_annotation): field_info = params.File(annotation=use_annotation, default=default_value) - elif not field_annotation_is_scalar(annotation=type_annotation): + elif ( + not field_annotation_is_scalar(type_annotation) + # Root models by default regarded as bodies + or field_annotation_is_root_model(type_annotation) + ): field_info = params.Body(annotation=use_annotation, default=default_value) else: field_info = params.Query(annotation=use_annotation, default=default_value) diff --git a/tests/test_root_model.py b/tests/test_root_model.py index b3d2acc6d..e2ba6e028 100644 --- a/tests/test_root_model.py +++ b/tests/test_root_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Type, Union import pytest from dirty_equals import IsDict @@ -131,6 +131,21 @@ def body_customparsed(b: CustomParsed = Body()): return {"b": b.parse()} +@app.post("/body_default/basic") +def body_default_basic(b: Basic): + return {"b": b} + + +@app.post("/body_default/fieldwrap") +def body_default_fieldwrap(b: FieldWrap): + return {"b": b} + + +@app.post("/body_default/customparsed") +def body_default_customparsed(b: CustomParsed): + return {"b": b.parse()} + + @app.get("/echo/basic") def echo_basic(q: Basic = Query()) -> Basic: return q @@ -166,6 +181,9 @@ client = TestClient(app) ("/body/basic", {"b": 42}, "42"), ("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"), ("/body/customparsed", {"b": "bar"}, "foo_bar"), + ("/body_default/basic", {"b": 42}, "42"), + ("/body_default/fieldwrap", {"b": "bar_baz"}, "bar_baz"), + ("/body_default/customparsed", {"b": "bar"}, "foo_bar"), ("/echo/basic?q=42", 42, None), ("/echo/fieldwrap?q=bar_baz", "bar_baz", None), ("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None), @@ -178,12 +196,53 @@ def test_root_model_200(url: str, response_json: Any, request_body: Any): assert response.json() == response_json +def test_root_model_union(): + if PYDANTIC_V2: + from pydantic import RootModel + + RootModelInt = RootModel[int] + RootModelStr = RootModel[str] + else: + + class RootModelInt(BaseModel): + __root__: int + + class RootModelStr(BaseModel): + __root__: str + + app2 = FastAPI() + + @app2.post("/union") + def union_handler(b: Union[RootModelInt, RootModelStr]): + return {"b": b} + + client2 = TestClient(app2) + for body in [42, "foo"]: + response = client2.post("/union", json=body) + assert response.status_code == 200, response.text + assert response.json() == {"b": body} + + response = client2.post("/union", json=["bad_list"]) + assert response.status_code == 422, response.text + if PYDANTIC_V2: + assert {detail["msg"] for detail in response.json()["detail"]} == { + "Input should be a valid integer", + "Input should be a valid string", + } + else: + assert {detail["msg"] for detail in response.json()["detail"]} == { + "value is not a valid integer", + "str type expected", + } + + @pytest.mark.parametrize( "url, error_path, request_body", [ ("/query/basic?q=my_bad_not_int", ["query", "q"], None), ("/path/basic/my_bad_not_int", ["path", "p"], None), ("/body/basic", ["body"], "my_bad_not_int"), + ("/body_default/basic", ["body"], "my_bad_not_int"), ], ) def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any): @@ -221,6 +280,7 @@ def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any ("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None), ("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None), ("/body/fieldwrap", ["body"], "my_bad_prefix_val"), + ("/body_default/fieldwrap", ["body"], "my_bad_prefix_val"), ], ) def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any): @@ -260,6 +320,7 @@ def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: ("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None), ("/path/customparsed/my_bad_prefix_val", ["path", "p"], None), ("/body/customparsed", ["body"], "my_bad_prefix_val"), + ("/body_default/customparsed", ["body"], "my_bad_prefix_val"), ], ) def test_root_model_customparsed_422( @@ -362,6 +423,10 @@ def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): "content": {"application/json": ref}, "required": True, } + assert paths[f"/body_default/{model.__name__.lower()}"]["post"]["requestBody"] == { + "content": {"application/json": ref}, + "required": True, + } assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { "content": {"application/json": ref}, "description": "Successful Response", From 65f1ba05e56bf9068bcebbc9de2cc25d57d9aef0 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Wed, 24 Apr 2024 20:58:14 +0300 Subject: [PATCH 3/4] Add PYDANTIC_V2 check to __root__ check. Use type alias in basic test --- fastapi/encoders.py | 8 +++++-- tests/test_root_model.py | 52 +++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/fastapi/encoders.py b/fastapi/encoders.py index aa5e04fc6..bc761a0d5 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -230,10 +230,14 @@ def jsonable_encoder( exclude_none=exclude_none, exclude_defaults=exclude_defaults, ) + if ( - isinstance(serialized, dict) and "__root__" in serialized - ): # TODO: remove when deprecating Pydantic v1 + not PYDANTIC_V2 + and isinstance(serialized, dict) + and "__root__" in serialized + ): serialized = serialized["__root__"] + return jsonable_encoder( serialized, exclude_none=exclude_none, diff --git a/tests/test_root_model.py b/tests/test_root_model.py index e2ba6e028..5c91bc10a 100644 --- a/tests/test_root_model.py +++ b/tests/test_root_model.py @@ -5,7 +5,6 @@ from dirty_equals import IsDict from fastapi import Body, FastAPI, Path, Query from fastapi._compat import PYDANTIC_V2 from fastapi.testclient import TestClient -from fastapi.utils import match_pydantic_error_url from pydantic import BaseModel app = FastAPI() @@ -13,8 +12,7 @@ app = FastAPI() if PYDANTIC_V2: from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer - class Basic(RootModel[int]): - pass + Basic = RootModel[int] class FieldWrap(RootModel[str]): model_config = ConfigDict( @@ -256,7 +254,6 @@ def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any "loc": error_path, "msg": "Input should be a valid integer, unable to parse string as an integer", "input": "my_bad_not_int", - "url": match_pydantic_error_url("int_parsing"), } ] } @@ -294,7 +291,6 @@ def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: "loc": error_path, "msg": "String should match pattern '^bar_.*$'", "input": "my_bad_prefix_val", - "url": match_pydantic_error_url("string_pattern_mismatch"), "ctx": {"pattern": "^bar_.*$"}, } ] @@ -337,7 +333,6 @@ def test_root_model_customparsed_422( "msg": "Value error, must start with foo_", "input": "my_bad_prefix_val", "ctx": {"error": {}}, - "url": match_pydantic_error_url("value_error"), } ] } @@ -366,7 +361,6 @@ def test_root_model_dictwrap_422(): "loc": ["body", "test"], "msg": "Input should be a valid integer, unable to parse string as an integer", "input": "fail_not_int", - "url": match_pydantic_error_url("int_parsing"), } ] } @@ -385,13 +379,13 @@ def test_root_model_dictwrap_422(): @pytest.mark.parametrize( - "model, model_schema", + "model, path_name, expected_model_schema", [ - (Basic, {"title": "Basic", "type": "integer"}), + (Basic, "basic", {"type": "integer"}), ( FieldWrap, + "fieldwrap", { - "title": "FieldWrap", "type": "string", "pattern": "^bar_.*$", "description": "parameter starts with bar_", @@ -399,42 +393,46 @@ def test_root_model_dictwrap_422(): ), ( CustomParsed, + "customparsed", { - "title": "CustomParsed", "type": "string", "description": "parameter starts with foo_", }, ), ], ) -def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): +def test_openapi_schema( + model: Type, path_name: str, expected_model_schema: Dict[str, Any] +): response = client.get("/openapi.json") assert response.status_code == 200, response.text paths = response.json()["paths"] - ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}} - assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [ - {"in": "query", "name": "q", "required": True, **ref} + ref_name = model.__name__.replace("[", "_").replace("]", "_") + schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}} + + assert paths[f"/query/{path_name}"]["get"]["parameters"] == [ + {"in": "query", "name": "q", "required": True, **schema_ref} ] - assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [ - {"in": "path", "name": "p", "required": True, **ref} + assert paths[f"/path/{path_name}/{{p}}"]["get"]["parameters"] == [ + {"in": "path", "name": "p", "required": True, **schema_ref} ] - assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == { - "content": {"application/json": ref}, + assert paths[f"/body/{path_name}"]["post"]["requestBody"] == { + "content": {"application/json": schema_ref}, "required": True, } - assert paths[f"/body_default/{model.__name__.lower()}"]["post"]["requestBody"] == { - "content": {"application/json": ref}, + assert paths[f"/body_default/{path_name}"]["post"]["requestBody"] == { + "content": {"application/json": schema_ref}, "required": True, } - assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { - "content": {"application/json": ref}, + assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == { + "content": {"application/json": schema_ref}, "description": "Successful Response", } - assert response.json()["components"]["schemas"][model.__name__] == { - "title": model.__name__, - **model_schema, - } + + model_schema = response.json()["components"]["schemas"][ref_name] + model_schema.pop("title") + assert model_schema == expected_model_schema def test_openapi_schema_dictwrap(): From c4daa7abab6b2e503115481e4198866efbb6afee Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Sat, 19 Jul 2025 22:17:58 +0300 Subject: [PATCH 4/4] Handle get_cached_model_fields. Fix type checks. Add more tests --- fastapi/_compat.py | 43 +++++++-- fastapi/dependencies/utils.py | 12 ++- tests/test_root_model.py | 167 +++++++++++++++++++++++----------- 3 files changed, 156 insertions(+), 66 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index bf9e46e68..9fc2cc3dd 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -239,7 +239,11 @@ if PYDANTIC_V2: return field_mapping, definitions # type: ignore[return-value] def is_scalar_field(field: ModelField) -> bool: - return field_annotation_is_scalar(field.field_info.annotation) + from fastapi import params + + return field_annotation_is_scalar( + field.field_info.annotation + ) and not isinstance(field.field_info, params.Body) def is_sequence_field(field: ModelField) -> bool: return field_annotation_is_sequence(field.field_info.annotation) @@ -402,12 +406,6 @@ else: def is_pv1_scalar_field(field: ModelField) -> bool: from fastapi import params - if ( - lenient_issubclass(field.type_, BaseModel) - and "__root__" in field.type_.__fields__ - ): - return is_pv1_scalar_field(field.type_.__fields__["__root__"]) - field_info = field.field_info if not ( field.shape == SHAPE_SINGLETON # type: ignore[attr-defined] @@ -501,12 +499,21 @@ else: ) def is_scalar_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_scalar(inner) + return is_pv1_scalar_field(field) def is_sequence_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_sequence(inner) + return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] def is_scalar_sequence_field(field: ModelField) -> bool: + if (inner := root_model_inner_type(field.type_)) is not None: + return field_annotation_is_scalar_sequence(inner) + return is_pv1_scalar_sequence_field(field) def is_bytes_field(field: ModelField) -> bool: @@ -564,6 +571,9 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_sequence(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: for arg in get_args(annotation): @@ -595,14 +605,21 @@ def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool: return root_model_inner_type(annotation) is not None +def field_annotation_has_submodel_fields(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_has_submodel_fields(inner) + + return lenient_issubclass(annotation, BaseModel) + + def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_complex(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) - if inner := root_model_inner_type(annotation): - return field_annotation_is_complex(inner) - return ( _annotation_is_complex(annotation) or _annotation_is_complex(origin) @@ -612,11 +629,17 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_scalar(annotation: Any) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_scalar(inner) + # handle Ellipsis here to make tuple[int, ...] work nicely return annotation is Ellipsis or not field_annotation_is_complex(annotation) def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool: + if (inner := root_model_inner_type(annotation)) is not None: + return field_annotation_is_scalar_sequence(inner) + origin = get_origin(annotation) if origin is Union or origin is UnionType: at_least_one_scalar_sequence = False diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 359004601..8cef138fd 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -30,6 +30,7 @@ from fastapi._compat import ( copy_field_info, create_body_model, evaluate_forwardref, + field_annotation_has_submodel_fields, field_annotation_is_root_model, field_annotation_is_scalar, get_annotation_from_field_info, @@ -214,7 +215,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: if not fields: return fields first_field = fields[0] - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): + if len(fields) == 1 and field_annotation_has_submodel_fields(first_field.type_): fields_to_extract = get_cached_model_fields(first_field.type_) return fields_to_extract return fields @@ -757,8 +758,9 @@ def request_params_to_args( single_not_embedded_field = False default_convert_underscores = True if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - single_not_embedded_field = True + if field_annotation_has_submodel_fields(first_field.type_): + fields_to_extract = get_cached_model_fields(first_field.type_) + single_not_embedded_field = True # If headers are in a Pydantic model, the way to disable convert_underscores # would be with Header(convert_underscores=False) at the Pydantic model level default_convert_underscores = getattr( @@ -921,7 +923,9 @@ async def request_body_to_args( fields_to_extract: List[ModelField] = body_fields - if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel): + if single_not_embedded_field and field_annotation_has_submodel_fields( + first_field.type_ + ): fields_to_extract = get_cached_model_fields(first_field.type_) if isinstance(received_body, FormData): diff --git a/tests/test_root_model.py b/tests/test_root_model.py index 5c91bc10a..78a568e10 100644 --- a/tests/test_root_model.py +++ b/tests/test_root_model.py @@ -1,18 +1,42 @@ -from typing import Any, Dict, List, Type, Union +from datetime import date +from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union import pytest from dirty_equals import IsDict -from fastapi import Body, FastAPI, Path, Query +from fastapi import Body, FastAPI, Header, Path, Query from fastapi._compat import PYDANTIC_V2 from fastapi.testclient import TestClient -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from .utils import needs_pydanticv2 app = FastAPI() if PYDANTIC_V2: - from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer + from pydantic import RootModel +else: + from pydantic import BaseModel + + T = TypeVar("T") + + class RootModel(Generic[T]): + def __class_getitem__(cls, type_): + class Inner(BaseModel): + __root__: type_ + + return Inner + + +Basic = RootModel[int] + + +class DictWrap(RootModel[Dict[str, int]]): + pass - Basic = RootModel[int] + +if PYDANTIC_V2: + from pydantic import ConfigDict, Field, field_validator, model_serializer class FieldWrap(RootModel[str]): model_config = ConfigDict( @@ -42,14 +66,8 @@ if PYDANTIC_V2: def parse(self): return self.root[len("foo_") :] # removeprefix - - class DictWrap(RootModel[Dict[str, int]]): - pass else: - from pydantic import Field, validator - - class Basic(BaseModel): - __root__: int + from pydantic import validator class FieldWrap(BaseModel): class Config: @@ -80,52 +98,49 @@ else: def parse(self): return self.__root__[len("foo_") :] # removeprefix - class DictWrap(BaseModel): - __root__: Dict[str, int] - @app.get("/query/basic") -def query_basic(q: Basic = Query()): +def query_basic(q: Annotated[Basic, Query()]): return {"q": q} @app.get("/query/fieldwrap") -def query_fieldwrap(q: FieldWrap = Query()): +def query_fieldwrap(q: Annotated[FieldWrap, Query()]): return {"q": q} @app.get("/query/customparsed") -def query_customparsed(q: CustomParsed = Query()): +def query_customparsed(q: Annotated[CustomParsed, Query()]): return {"q": q.parse()} @app.get("/path/basic/{p}") -def path_basic(p: Basic = Path()): +def path_basic(p: Annotated[Basic, Path()]): return {"p": p} @app.get("/path/fieldwrap/{p}") -def path_fieldwrap(p: FieldWrap = Path()): +def path_fieldwrap(p: Annotated[FieldWrap, Path()]): return {"p": p} @app.get("/path/customparsed/{p}") -def path_customparsed(p: CustomParsed = Path()): +def path_customparsed(p: Annotated[CustomParsed, Path()]): return {"p": p.parse()} @app.post("/body/basic") -def body_basic(b: Basic = Body()): +def body_basic(b: Annotated[Basic, Body()]): return {"b": b} @app.post("/body/fieldwrap") -def body_fieldwrap(b: FieldWrap = Body()): +def body_fieldwrap(b: Annotated[FieldWrap, Body()]): return {"b": b} @app.post("/body/customparsed") -def body_customparsed(b: CustomParsed = Body()): +def body_customparsed(b: Annotated[CustomParsed, Body()]): return {"b": b.parse()} @@ -145,17 +160,17 @@ def body_default_customparsed(b: CustomParsed): @app.get("/echo/basic") -def echo_basic(q: Basic = Query()) -> Basic: +def echo_basic(q: Annotated[Basic, Query()]) -> Basic: return q @app.get("/echo/fieldwrap") -def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap: +def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap: return q @app.get("/echo/customparsed") -def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed: +def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed: return q @@ -194,44 +209,92 @@ def test_root_model_200(url: str, response_json: Any, request_body: Any): assert response.json() == response_json -def test_root_model_union(): - if PYDANTIC_V2: - from pydantic import RootModel +@pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]]) +@pytest.mark.parametrize("outer", [None, List, Sequence]) +def test2_root_model_200__basic(type_: Type, outer: Optional[Type]): + inner = outer[type_] if outer else type_ + Model = RootModel[inner] + + app2 = FastAPI() + + if outer is None: + + @app2.get("/path/{p}") + def path_basic2(p: Annotated[Model, Path()]) -> Model: + return p + else: + with pytest.raises( + AssertionError, match="Path params must be of one of the supported types" + ): + + @app2.get("/path/{p}") + def path_basic2(p: Annotated[Model, Path()]) -> Model: + return p # pragma: nocover + + @app2.get("/query") + def query_basic2(q: Annotated[Model, Query()]) -> Model: + return q + + @app2.get("/header") + def path_basic2(h: Annotated[Model, Header()]) -> Model: + return h - RootModelInt = RootModel[int] - RootModelStr = RootModel[str] + @app2.post("/body") + def body_basic2(b: Annotated[Model, Body()]) -> Model: + return b + + client2 = TestClient(app2) + + if outer is None: + expected = "42" if type_ in (str, bytes) else 42 else: + expected = ["42", "43"] if type_ in (str, bytes) else [42, 43] + + if outer is None: + assert client2.get("/path/42").json() == expected + assert client2.get("/query?q=42").json() == expected + assert client2.get("/header", headers={"h": "42"}).json() == expected + else: + assert client2.get("/path/42").json()["detail"] == "Not Found" + assert client2.get("/query?q=42&q=43").json() == expected + assert ( + client2.get("/header", headers=[("h", "42"), ("h", "43")]).json() + == expected + ) + assert client2.post("/body", json=expected).json() == expected - class RootModelInt(BaseModel): - __root__: int - class RootModelStr(BaseModel): - __root__: str +@pytest.mark.parametrize("left", [RootModel[date], date]) +@pytest.mark.parametrize("right", [RootModel[int], int]) +@needs_pydanticv2 +def test_root_model_union(left: Any, right: Any): + Model = Union[left, right] app2 = FastAPI() - @app2.post("/union") - def union_handler(b: Union[RootModelInt, RootModelStr]): + @app2.get("/query") + def handler1(q: Annotated[Model, Query()]): + return {"q": q} + + @app2.post("/body") + def handler2(b: Annotated[Model, Body()]): return {"b": b} client2 = TestClient(app2) - for body in [42, "foo"]: - response = client2.post("/union", json=body) + for val in ["2025-01-02", 42]: + response = client2.get(f"/query?q={val}") + assert response.status_code == 200, response.text + assert response.json() == {"q": val} + response = client2.post("/body", json=val) assert response.status_code == 200, response.text - assert response.json() == {"b": body} + assert response.json() == {"b": val} - response = client2.post("/union", json=["bad_list"]) + response = client2.get("/query?q=bad") assert response.status_code == 422, response.text - if PYDANTIC_V2: - assert {detail["msg"] for detail in response.json()["detail"]} == { - "Input should be a valid integer", - "Input should be a valid string", - } - else: - assert {detail["msg"] for detail in response.json()["detail"]} == { - "value is not a valid integer", - "str type expected", - } + assert {detail["msg"] for detail in response.json()["detail"]} == { + "Input should be a valid date or datetime, input is too short", + "Input should be a valid integer, unable to parse string as an integer", + } @pytest.mark.parametrize(