From 867d75b5c8d8f8dc7e8ebcbf20bd658a71e41331 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Sun, 17 Mar 2024 15:58:31 +0200 Subject: [PATCH] Support Pydantic root model as query parameter --- fastapi/_compat.py | 30 ++- fastapi/encoders.py | 10 +- tests/test_root_model.py | 393 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 427 insertions(+), 6 deletions(-) create mode 100644 tests/test_root_model.py diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 227ad837d..8f25244e9 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -50,7 +50,7 @@ Url: Type[Any] if PYDANTIC_V2: from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError - from pydantic import TypeAdapter + from pydantic import RootModel, TypeAdapter from pydantic import ValidationError as ValidationError from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] GetJsonSchemaHandler as GetJsonSchemaHandler, @@ -292,6 +292,11 @@ if PYDANTIC_V2: for name, field_info in model.model_fields.items() ] + def root_model_inner_type(annotation: Any) -> Any: + if lenient_issubclass(annotation, RootModel): + return annotation.model_fields["root"].annotation + return None + else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 @@ -401,6 +406,12 @@ else: def is_pv1_scalar_field(field: ModelField) -> bool: from fastapi import params + if ( + lenient_issubclass(field.type_, BaseModel) + and "__root__" in field.type_.__fields__ + ): + return is_pv1_scalar_field(field.type_.__fields__["__root__"]) + field_info = field.field_info if not ( field.shape == SHAPE_SINGLETON # type: ignore[attr-defined] @@ -494,7 +505,11 @@ else: ) def is_scalar_field(field: ModelField) -> bool: - return is_pv1_scalar_field(field) + from fastapi import params + + return is_pv1_scalar_field(field) and not isinstance( + field.field_info, params.Body + ) def is_sequence_field(field: ModelField) -> bool: return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] @@ -530,6 +545,14 @@ else: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: return list(model.__fields__.values()) # type: ignore[attr-defined] + def root_model_inner_type(annotation: Any) -> Any: + if ( + lenient_issubclass(annotation, BaseModel) + and "__root__" in annotation.__fields__ + ): + return annotation.__fields__["__root__"].annotation + return None + def _regenerate_error_with_loc( *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] @@ -577,6 +600,9 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: if origin is Union or origin is UnionType: return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) + if inner := root_model_inner_type(annotation): + return field_annotation_is_complex(inner) + return ( _annotation_is_complex(annotation) or _annotation_is_complex(origin) diff --git a/fastapi/encoders.py b/fastapi/encoders.py index 451ea0760..aa5e04fc6 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -220,7 +220,7 @@ def jsonable_encoder( encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: encoders.update(custom_encoder) - obj_dict = _model_dump( + serialized = _model_dump( obj, mode="json", include=include, @@ -230,10 +230,12 @@ def jsonable_encoder( exclude_none=exclude_none, exclude_defaults=exclude_defaults, ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] + if ( + isinstance(serialized, dict) and "__root__" in serialized + ): # TODO: remove when deprecating Pydantic v1 + serialized = serialized["__root__"] return jsonable_encoder( - obj_dict, + serialized, exclude_none=exclude_none, exclude_defaults=exclude_defaults, # TODO: remove when deprecating Pydantic v1 diff --git a/tests/test_root_model.py b/tests/test_root_model.py new file mode 100644 index 000000000..b3d2acc6d --- /dev/null +++ b/tests/test_root_model.py @@ -0,0 +1,393 @@ +from typing import Any, Dict, List, Type + +import pytest +from dirty_equals import IsDict +from fastapi import Body, FastAPI, Path, Query +from fastapi._compat import PYDANTIC_V2 +from fastapi.testclient import TestClient +from fastapi.utils import match_pydantic_error_url +from pydantic import BaseModel + +app = FastAPI() + +if PYDANTIC_V2: + from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer + + class Basic(RootModel[int]): + pass + + class FieldWrap(RootModel[str]): + model_config = ConfigDict( + json_schema_extra={"description": "parameter starts with bar_"} + ) + root: str = Field(pattern=r"^bar_.*$") + + class CustomParsed(RootModel[str]): + model_config = ConfigDict( + json_schema_extra={"description": "parameter starts with foo_"} + ) + + @field_validator("root") + @classmethod + def validator(cls, v: str) -> str: + if not v.startswith("foo_"): + raise ValueError("must start with foo_") + return v + + @model_serializer + def serialize(self): + return ( + self.root + if self.root.endswith("_serialized") + else f"{self.root}_serialized" + ) + + def parse(self): + return self.root[len("foo_") :] # removeprefix + + class DictWrap(RootModel[Dict[str, int]]): + pass +else: + from pydantic import Field, validator + + class Basic(BaseModel): + __root__: int + + class FieldWrap(BaseModel): + class Config: + schema_extra = {"description": "parameter starts with bar_"} + + __root__: str = Field(regex=r"^bar_.*$") + + class CustomParsed(BaseModel): + class Config: + schema_extra = {"description": "parameter starts with foo_"} + + __root__: str + + @validator("__root__") + @classmethod + def validator(cls, v: str) -> str: + if not v.startswith("foo_"): + raise ValueError("must start with foo_") + return v + + def dict(self, **kwargs: Any) -> Dict[str, Any]: + return { + "__root__": self.__root__ + if self.__root__.endswith("_serialized") + else f"{self.__root__}_serialized" + } + + def parse(self): + return self.__root__[len("foo_") :] # removeprefix + + class DictWrap(BaseModel): + __root__: Dict[str, int] + + +@app.get("/query/basic") +def query_basic(q: Basic = Query()): + return {"q": q} + + +@app.get("/query/fieldwrap") +def query_fieldwrap(q: FieldWrap = Query()): + return {"q": q} + + +@app.get("/query/customparsed") +def query_customparsed(q: CustomParsed = Query()): + return {"q": q.parse()} + + +@app.get("/path/basic/{p}") +def path_basic(p: Basic = Path()): + return {"p": p} + + +@app.get("/path/fieldwrap/{p}") +def path_fieldwrap(p: FieldWrap = Path()): + return {"p": p} + + +@app.get("/path/customparsed/{p}") +def path_customparsed(p: CustomParsed = Path()): + return {"p": p.parse()} + + +@app.post("/body/basic") +def body_basic(b: Basic = Body()): + return {"b": b} + + +@app.post("/body/fieldwrap") +def body_fieldwrap(b: FieldWrap = Body()): + return {"b": b} + + +@app.post("/body/customparsed") +def body_customparsed(b: CustomParsed = Body()): + return {"b": b.parse()} + + +@app.get("/echo/basic") +def echo_basic(q: Basic = Query()) -> Basic: + return q + + +@app.get("/echo/fieldwrap") +def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap: + return q + + +@app.get("/echo/customparsed") +def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed: + return q + + +@app.post("/echo/dictwrap") +def echo_dictwrap(b: DictWrap) -> DictWrap: + return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"] + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "url, response_json, request_body", + [ + ("/query/basic?q=42", {"q": 42}, None), + ("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None), + ("/query/customparsed?q=foo_bar", {"q": "bar"}, None), + ("/path/basic/42", {"p": 42}, None), + ("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None), + ("/path/customparsed/foo_bar", {"p": "bar"}, None), + ("/body/basic", {"b": 42}, "42"), + ("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"), + ("/body/customparsed", {"b": "bar"}, "foo_bar"), + ("/echo/basic?q=42", 42, None), + ("/echo/fieldwrap?q=bar_baz", "bar_baz", None), + ("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None), + ("/echo/dictwrap", {"test": 1}, {"test": 1}), + ], +) +def test_root_model_200(url: str, response_json: Any, request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 200, response.text + assert response.json() == response_json + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/basic?q=my_bad_not_int", ["query", "q"], None), + ("/path/basic/my_bad_not_int", ["path", "p"], None), + ("/body/basic", ["body"], "my_bad_not_int"), + ], +) +def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "int_parsing", + "loc": error_path, + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "my_bad_not_int", + "url": match_pydantic_error_url("int_parsing"), + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + ) + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None), + ("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None), + ("/body/fieldwrap", ["body"], "my_bad_prefix_val"), + ], +) +def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "string_pattern_mismatch", + "loc": error_path, + "msg": "String should match pattern '^bar_.*$'", + "input": "my_bad_prefix_val", + "url": match_pydantic_error_url("string_pattern_mismatch"), + "ctx": {"pattern": "^bar_.*$"}, + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": 'string does not match regex "^bar_.*$"', + "type": "value_error.str.regex", + "ctx": {"pattern": "^bar_.*$"}, + } + ] + } + ) + + +@pytest.mark.parametrize( + "url, error_path, request_body", + [ + ("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None), + ("/path/customparsed/my_bad_prefix_val", ["path", "p"], None), + ("/body/customparsed", ["body"], "my_bad_prefix_val"), + ], +) +def test_root_model_customparsed_422( + url: str, error_path: List[str], request_body: Any +): + response = client.post(url, json=request_body) if request_body else client.get(url) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "value_error", + "loc": error_path, + "msg": "Value error, must start with foo_", + "input": "my_bad_prefix_val", + "ctx": {"error": {}}, + "url": match_pydantic_error_url("value_error"), + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": [*error_path, "__root__"], + "msg": "must start with foo_", + "type": "value_error", + } + ] + } + ) + + +def test_root_model_dictwrap_422(): + response = client.post("/echo/dictwrap", json={"test": "fail_not_int"}) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "type": "int_parsing", + "loc": ["body", "test"], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "fail_not_int", + "url": match_pydantic_error_url("int_parsing"), + } + ] + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": ["body", "__root__", "test"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + ) + + +@pytest.mark.parametrize( + "model, model_schema", + [ + (Basic, {"title": "Basic", "type": "integer"}), + ( + FieldWrap, + { + "title": "FieldWrap", + "type": "string", + "pattern": "^bar_.*$", + "description": "parameter starts with bar_", + }, + ), + ( + CustomParsed, + { + "title": "CustomParsed", + "type": "string", + "description": "parameter starts with foo_", + }, + ), + ], +) +def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + paths = response.json()["paths"] + ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}} + assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [ + {"in": "query", "name": "q", "required": True, **ref} + ] + assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [ + {"in": "path", "name": "p", "required": True, **ref} + ] + assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == { + "content": {"application/json": ref}, + "required": True, + } + assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { + "content": {"application/json": ref}, + "description": "Successful Response", + } + assert response.json()["components"]["schemas"][model.__name__] == { + "title": model.__name__, + **model_schema, + } + + +def test_openapi_schema_dictwrap(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + operation = response.json()["paths"]["/echo/dictwrap"]["post"] + ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}} + assert operation["requestBody"] == { + "content": {"application/json": ref}, + "required": True, + } + assert operation["responses"]["200"] == { + "content": {"application/json": ref}, + "description": "Successful Response", + } + assert response.json()["components"]["schemas"]["DictWrap"] == { + "title": "DictWrap", + "type": "object", + "additionalProperties": {"type": "integer"}, + }