Browse Source

Support Pydantic root model as query parameter

pull/11306/head
Markus Sintonen 1 year ago
parent
commit
867d75b5c8
  1. 30
      fastapi/_compat.py
  2. 10
      fastapi/encoders.py
  3. 393
      tests/test_root_model.py

30
fastapi/_compat.py

@ -50,7 +50,7 @@ Url: Type[Any]
if PYDANTIC_V2: if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter from pydantic import RootModel, TypeAdapter
from pydantic import ValidationError as ValidationError from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler, GetJsonSchemaHandler as GetJsonSchemaHandler,
@ -292,6 +292,11 @@ if PYDANTIC_V2:
for name, field_info in model.model_fields.items() for name, field_info in model.model_fields.items()
] ]
def root_model_inner_type(annotation: Any) -> Any:
if lenient_issubclass(annotation, RootModel):
return annotation.model_fields["root"].annotation
return None
else: else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401 from pydantic import AnyUrl as Url # noqa: F401
@ -401,6 +406,12 @@ else:
def is_pv1_scalar_field(field: ModelField) -> bool: def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params from fastapi import params
if (
lenient_issubclass(field.type_, BaseModel)
and "__root__" in field.type_.__fields__
):
return is_pv1_scalar_field(field.type_.__fields__["__root__"])
field_info = field.field_info field_info = field.field_info
if not ( if not (
field.shape == SHAPE_SINGLETON # type: ignore[attr-defined] field.shape == SHAPE_SINGLETON # type: ignore[attr-defined]
@ -494,7 +505,11 @@ else:
) )
def is_scalar_field(field: ModelField) -> bool: def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field) from fastapi import params
return is_pv1_scalar_field(field) and not isinstance(
field.field_info, params.Body
)
def is_sequence_field(field: ModelField) -> bool: def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
@ -530,6 +545,14 @@ else:
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined] return list(model.__fields__.values()) # type: ignore[attr-defined]
def root_model_inner_type(annotation: Any) -> Any:
if (
lenient_issubclass(annotation, BaseModel)
and "__root__" in annotation.__fields__
):
return annotation.__fields__["__root__"].annotation
return None
def _regenerate_error_with_loc( def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
@ -577,6 +600,9 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
if inner := root_model_inner_type(annotation):
return field_annotation_is_complex(inner)
return ( return (
_annotation_is_complex(annotation) _annotation_is_complex(annotation)
or _annotation_is_complex(origin) or _annotation_is_complex(origin)

10
fastapi/encoders.py

@ -220,7 +220,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder: if custom_encoder:
encoders.update(custom_encoder) encoders.update(custom_encoder)
obj_dict = _model_dump( serialized = _model_dump(
obj, obj,
mode="json", mode="json",
include=include, include=include,
@ -230,10 +230,12 @@ def jsonable_encoder(
exclude_none=exclude_none, exclude_none=exclude_none,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
) )
if "__root__" in obj_dict: if (
obj_dict = obj_dict["__root__"] isinstance(serialized, dict) and "__root__" in serialized
): # TODO: remove when deprecating Pydantic v1
serialized = serialized["__root__"]
return jsonable_encoder( return jsonable_encoder(
obj_dict, serialized,
exclude_none=exclude_none, exclude_none=exclude_none,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
# TODO: remove when deprecating Pydantic v1 # TODO: remove when deprecating Pydantic v1

393
tests/test_root_model.py

@ -0,0 +1,393 @@
from typing import Any, Dict, List, Type
import pytest
from dirty_equals import IsDict
from fastapi import Body, FastAPI, Path, Query
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
from fastapi.utils import match_pydantic_error_url
from pydantic import BaseModel
app = FastAPI()
if PYDANTIC_V2:
from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer
class Basic(RootModel[int]):
pass
class FieldWrap(RootModel[str]):
model_config = ConfigDict(
json_schema_extra={"description": "parameter starts with bar_"}
)
root: str = Field(pattern=r"^bar_.*$")
class CustomParsed(RootModel[str]):
model_config = ConfigDict(
json_schema_extra={"description": "parameter starts with foo_"}
)
@field_validator("root")
@classmethod
def validator(cls, v: str) -> str:
if not v.startswith("foo_"):
raise ValueError("must start with foo_")
return v
@model_serializer
def serialize(self):
return (
self.root
if self.root.endswith("_serialized")
else f"{self.root}_serialized"
)
def parse(self):
return self.root[len("foo_") :] # removeprefix
class DictWrap(RootModel[Dict[str, int]]):
pass
else:
from pydantic import Field, validator
class Basic(BaseModel):
__root__: int
class FieldWrap(BaseModel):
class Config:
schema_extra = {"description": "parameter starts with bar_"}
__root__: str = Field(regex=r"^bar_.*$")
class CustomParsed(BaseModel):
class Config:
schema_extra = {"description": "parameter starts with foo_"}
__root__: str
@validator("__root__")
@classmethod
def validator(cls, v: str) -> str:
if not v.startswith("foo_"):
raise ValueError("must start with foo_")
return v
def dict(self, **kwargs: Any) -> Dict[str, Any]:
return {
"__root__": self.__root__
if self.__root__.endswith("_serialized")
else f"{self.__root__}_serialized"
}
def parse(self):
return self.__root__[len("foo_") :] # removeprefix
class DictWrap(BaseModel):
__root__: Dict[str, int]
@app.get("/query/basic")
def query_basic(q: Basic = Query()):
return {"q": q}
@app.get("/query/fieldwrap")
def query_fieldwrap(q: FieldWrap = Query()):
return {"q": q}
@app.get("/query/customparsed")
def query_customparsed(q: CustomParsed = Query()):
return {"q": q.parse()}
@app.get("/path/basic/{p}")
def path_basic(p: Basic = Path()):
return {"p": p}
@app.get("/path/fieldwrap/{p}")
def path_fieldwrap(p: FieldWrap = Path()):
return {"p": p}
@app.get("/path/customparsed/{p}")
def path_customparsed(p: CustomParsed = Path()):
return {"p": p.parse()}
@app.post("/body/basic")
def body_basic(b: Basic = Body()):
return {"b": b}
@app.post("/body/fieldwrap")
def body_fieldwrap(b: FieldWrap = Body()):
return {"b": b}
@app.post("/body/customparsed")
def body_customparsed(b: CustomParsed = Body()):
return {"b": b.parse()}
@app.get("/echo/basic")
def echo_basic(q: Basic = Query()) -> Basic:
return q
@app.get("/echo/fieldwrap")
def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap:
return q
@app.get("/echo/customparsed")
def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed:
return q
@app.post("/echo/dictwrap")
def echo_dictwrap(b: DictWrap) -> DictWrap:
return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"]
client = TestClient(app)
@pytest.mark.parametrize(
"url, response_json, request_body",
[
("/query/basic?q=42", {"q": 42}, None),
("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None),
("/query/customparsed?q=foo_bar", {"q": "bar"}, None),
("/path/basic/42", {"p": 42}, None),
("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None),
("/path/customparsed/foo_bar", {"p": "bar"}, None),
("/body/basic", {"b": 42}, "42"),
("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
("/body/customparsed", {"b": "bar"}, "foo_bar"),
("/echo/basic?q=42", 42, None),
("/echo/fieldwrap?q=bar_baz", "bar_baz", None),
("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None),
("/echo/dictwrap", {"test": 1}, {"test": 1}),
],
)
def test_root_model_200(url: str, response_json: Any, request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 200, response.text
assert response.json() == response_json
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/basic?q=my_bad_not_int", ["query", "q"], None),
("/path/basic/my_bad_not_int", ["path", "p"], None),
("/body/basic", ["body"], "my_bad_not_int"),
],
)
def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": error_path,
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "my_bad_not_int",
"url": match_pydantic_error_url("int_parsing"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
]
}
)
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None),
("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None),
("/body/fieldwrap", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "string_pattern_mismatch",
"loc": error_path,
"msg": "String should match pattern '^bar_.*$'",
"input": "my_bad_prefix_val",
"url": match_pydantic_error_url("string_pattern_mismatch"),
"ctx": {"pattern": "^bar_.*$"},
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": 'string does not match regex "^bar_.*$"',
"type": "value_error.str.regex",
"ctx": {"pattern": "^bar_.*$"},
}
]
}
)
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None),
("/path/customparsed/my_bad_prefix_val", ["path", "p"], None),
("/body/customparsed", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_customparsed_422(
url: str, error_path: List[str], request_body: Any
):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "value_error",
"loc": error_path,
"msg": "Value error, must start with foo_",
"input": "my_bad_prefix_val",
"ctx": {"error": {}},
"url": match_pydantic_error_url("value_error"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": "must start with foo_",
"type": "value_error",
}
]
}
)
def test_root_model_dictwrap_422():
response = client.post("/echo/dictwrap", json={"test": "fail_not_int"})
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": ["body", "test"],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "fail_not_int",
"url": match_pydantic_error_url("int_parsing"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": ["body", "__root__", "test"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
]
}
)
@pytest.mark.parametrize(
"model, model_schema",
[
(Basic, {"title": "Basic", "type": "integer"}),
(
FieldWrap,
{
"title": "FieldWrap",
"type": "string",
"pattern": "^bar_.*$",
"description": "parameter starts with bar_",
},
),
(
CustomParsed,
{
"title": "CustomParsed",
"type": "string",
"description": "parameter starts with foo_",
},
),
],
)
def test_openapi_schema(model: Type, model_schema: Dict[str, Any]):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
paths = response.json()["paths"]
ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}}
assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [
{"in": "query", "name": "q", "required": True, **ref}
]
assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [
{"in": "path", "name": "p", "required": True, **ref}
]
assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == {
"content": {"application/json": ref},
"required": True,
}
assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == {
"content": {"application/json": ref},
"description": "Successful Response",
}
assert response.json()["components"]["schemas"][model.__name__] == {
"title": model.__name__,
**model_schema,
}
def test_openapi_schema_dictwrap():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
operation = response.json()["paths"]["/echo/dictwrap"]["post"]
ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}}
assert operation["requestBody"] == {
"content": {"application/json": ref},
"required": True,
}
assert operation["responses"]["200"] == {
"content": {"application/json": ref},
"description": "Successful Response",
}
assert response.json()["components"]["schemas"]["DictWrap"] == {
"title": "DictWrap",
"type": "object",
"additionalProperties": {"type": "integer"},
}
Loading…
Cancel
Save