Browse Source

Support Pydantic root model as query parameter

pull/11306/head
Markus Sintonen 1 year ago
parent
commit
867d75b5c8
  1. 30
      fastapi/_compat.py
  2. 10
      fastapi/encoders.py
  3. 393
      tests/test_root_model.py

30
fastapi/_compat.py

@ -50,7 +50,7 @@ Url: Type[Any]
if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
from pydantic import RootModel, TypeAdapter
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
@ -292,6 +292,11 @@ if PYDANTIC_V2:
for name, field_info in model.model_fields.items()
]
def root_model_inner_type(annotation: Any) -> Any:
if lenient_issubclass(annotation, RootModel):
return annotation.model_fields["root"].annotation
return None
else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
@ -401,6 +406,12 @@ else:
def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params
if (
lenient_issubclass(field.type_, BaseModel)
and "__root__" in field.type_.__fields__
):
return is_pv1_scalar_field(field.type_.__fields__["__root__"])
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON # type: ignore[attr-defined]
@ -494,7 +505,11 @@ else:
)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
from fastapi import params
return is_pv1_scalar_field(field) and not isinstance(
field.field_info, params.Body
)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
@ -530,6 +545,14 @@ else:
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]
def root_model_inner_type(annotation: Any) -> Any:
if (
lenient_issubclass(annotation, BaseModel)
and "__root__" in annotation.__fields__
):
return annotation.__fields__["__root__"].annotation
return None
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
@ -577,6 +600,9 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
if inner := root_model_inner_type(annotation):
return field_annotation_is_complex(inner)
return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)

10
fastapi/encoders.py

@ -220,7 +220,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
obj_dict = _model_dump(
serialized = _model_dump(
obj,
mode="json",
include=include,
@ -230,10 +230,12 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
if (
isinstance(serialized, dict) and "__root__" in serialized
): # TODO: remove when deprecating Pydantic v1
serialized = serialized["__root__"]
return jsonable_encoder(
obj_dict,
serialized,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
# TODO: remove when deprecating Pydantic v1

393
tests/test_root_model.py

@ -0,0 +1,393 @@
from typing import Any, Dict, List, Type
import pytest
from dirty_equals import IsDict
from fastapi import Body, FastAPI, Path, Query
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
from fastapi.utils import match_pydantic_error_url
from pydantic import BaseModel
app = FastAPI()
if PYDANTIC_V2:
from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer
class Basic(RootModel[int]):
pass
class FieldWrap(RootModel[str]):
model_config = ConfigDict(
json_schema_extra={"description": "parameter starts with bar_"}
)
root: str = Field(pattern=r"^bar_.*$")
class CustomParsed(RootModel[str]):
model_config = ConfigDict(
json_schema_extra={"description": "parameter starts with foo_"}
)
@field_validator("root")
@classmethod
def validator(cls, v: str) -> str:
if not v.startswith("foo_"):
raise ValueError("must start with foo_")
return v
@model_serializer
def serialize(self):
return (
self.root
if self.root.endswith("_serialized")
else f"{self.root}_serialized"
)
def parse(self):
return self.root[len("foo_") :] # removeprefix
class DictWrap(RootModel[Dict[str, int]]):
pass
else:
from pydantic import Field, validator
class Basic(BaseModel):
__root__: int
class FieldWrap(BaseModel):
class Config:
schema_extra = {"description": "parameter starts with bar_"}
__root__: str = Field(regex=r"^bar_.*$")
class CustomParsed(BaseModel):
class Config:
schema_extra = {"description": "parameter starts with foo_"}
__root__: str
@validator("__root__")
@classmethod
def validator(cls, v: str) -> str:
if not v.startswith("foo_"):
raise ValueError("must start with foo_")
return v
def dict(self, **kwargs: Any) -> Dict[str, Any]:
return {
"__root__": self.__root__
if self.__root__.endswith("_serialized")
else f"{self.__root__}_serialized"
}
def parse(self):
return self.__root__[len("foo_") :] # removeprefix
class DictWrap(BaseModel):
__root__: Dict[str, int]
@app.get("/query/basic")
def query_basic(q: Basic = Query()):
return {"q": q}
@app.get("/query/fieldwrap")
def query_fieldwrap(q: FieldWrap = Query()):
return {"q": q}
@app.get("/query/customparsed")
def query_customparsed(q: CustomParsed = Query()):
return {"q": q.parse()}
@app.get("/path/basic/{p}")
def path_basic(p: Basic = Path()):
return {"p": p}
@app.get("/path/fieldwrap/{p}")
def path_fieldwrap(p: FieldWrap = Path()):
return {"p": p}
@app.get("/path/customparsed/{p}")
def path_customparsed(p: CustomParsed = Path()):
return {"p": p.parse()}
@app.post("/body/basic")
def body_basic(b: Basic = Body()):
return {"b": b}
@app.post("/body/fieldwrap")
def body_fieldwrap(b: FieldWrap = Body()):
return {"b": b}
@app.post("/body/customparsed")
def body_customparsed(b: CustomParsed = Body()):
return {"b": b.parse()}
@app.get("/echo/basic")
def echo_basic(q: Basic = Query()) -> Basic:
return q
@app.get("/echo/fieldwrap")
def echo_fieldwrap(q: FieldWrap = Query()) -> FieldWrap:
return q
@app.get("/echo/customparsed")
def echo_customparsed(q: CustomParsed = Query()) -> CustomParsed:
return q
@app.post("/echo/dictwrap")
def echo_dictwrap(b: DictWrap) -> DictWrap:
return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"]
client = TestClient(app)
@pytest.mark.parametrize(
"url, response_json, request_body",
[
("/query/basic?q=42", {"q": 42}, None),
("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None),
("/query/customparsed?q=foo_bar", {"q": "bar"}, None),
("/path/basic/42", {"p": 42}, None),
("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None),
("/path/customparsed/foo_bar", {"p": "bar"}, None),
("/body/basic", {"b": 42}, "42"),
("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
("/body/customparsed", {"b": "bar"}, "foo_bar"),
("/echo/basic?q=42", 42, None),
("/echo/fieldwrap?q=bar_baz", "bar_baz", None),
("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None),
("/echo/dictwrap", {"test": 1}, {"test": 1}),
],
)
def test_root_model_200(url: str, response_json: Any, request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 200, response.text
assert response.json() == response_json
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/basic?q=my_bad_not_int", ["query", "q"], None),
("/path/basic/my_bad_not_int", ["path", "p"], None),
("/body/basic", ["body"], "my_bad_not_int"),
],
)
def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": error_path,
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "my_bad_not_int",
"url": match_pydantic_error_url("int_parsing"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
]
}
)
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None),
("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None),
("/body/fieldwrap", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "string_pattern_mismatch",
"loc": error_path,
"msg": "String should match pattern '^bar_.*$'",
"input": "my_bad_prefix_val",
"url": match_pydantic_error_url("string_pattern_mismatch"),
"ctx": {"pattern": "^bar_.*$"},
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": 'string does not match regex "^bar_.*$"',
"type": "value_error.str.regex",
"ctx": {"pattern": "^bar_.*$"},
}
]
}
)
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None),
("/path/customparsed/my_bad_prefix_val", ["path", "p"], None),
("/body/customparsed", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_customparsed_422(
url: str, error_path: List[str], request_body: Any
):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "value_error",
"loc": error_path,
"msg": "Value error, must start with foo_",
"input": "my_bad_prefix_val",
"ctx": {"error": {}},
"url": match_pydantic_error_url("value_error"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": "must start with foo_",
"type": "value_error",
}
]
}
)
def test_root_model_dictwrap_422():
response = client.post("/echo/dictwrap", json={"test": "fail_not_int"})
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": ["body", "test"],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "fail_not_int",
"url": match_pydantic_error_url("int_parsing"),
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": ["body", "__root__", "test"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
]
}
)
@pytest.mark.parametrize(
"model, model_schema",
[
(Basic, {"title": "Basic", "type": "integer"}),
(
FieldWrap,
{
"title": "FieldWrap",
"type": "string",
"pattern": "^bar_.*$",
"description": "parameter starts with bar_",
},
),
(
CustomParsed,
{
"title": "CustomParsed",
"type": "string",
"description": "parameter starts with foo_",
},
),
],
)
def test_openapi_schema(model: Type, model_schema: Dict[str, Any]):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
paths = response.json()["paths"]
ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}}
assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [
{"in": "query", "name": "q", "required": True, **ref}
]
assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [
{"in": "path", "name": "p", "required": True, **ref}
]
assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == {
"content": {"application/json": ref},
"required": True,
}
assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == {
"content": {"application/json": ref},
"description": "Successful Response",
}
assert response.json()["components"]["schemas"][model.__name__] == {
"title": model.__name__,
**model_schema,
}
def test_openapi_schema_dictwrap():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
operation = response.json()["paths"]["/echo/dictwrap"]["post"]
ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}}
assert operation["requestBody"] == {
"content": {"application/json": ref},
"required": True,
}
assert operation["responses"]["200"] == {
"content": {"application/json": ref},
"description": "Successful Response",
}
assert response.json()["components"]["schemas"]["DictWrap"] == {
"title": "DictWrap",
"type": "object",
"additionalProperties": {"type": "integer"},
}
Loading…
Cancel
Save