Browse Source

Merge c4daa7abab into 6e69d62bfe

pull/11306/merge
MarkusSintonen 1 day ago
committed by GitHub
parent
commit
608ccbe782
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 51
      fastapi/_compat.py
  2. 19
      fastapi/dependencies/utils.py
  3. 14
      fastapi/encoders.py
  4. 519
      tests/test_root_model.py

51
fastapi/_compat.py

@ -50,7 +50,7 @@ Url: Type[Any]
if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
from pydantic import RootModel, TypeAdapter
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
@ -292,6 +292,11 @@ if PYDANTIC_V2:
for name, field_info in model.model_fields.items()
]
def root_model_inner_type(annotation: Any) -> Any:
if lenient_issubclass(annotation, RootModel):
return annotation.model_fields["root"].annotation
return None
else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
@ -494,12 +499,21 @@ else:
)
def is_scalar_field(field: ModelField) -> bool:
if (inner := root_model_inner_type(field.type_)) is not None:
return field_annotation_is_scalar(inner)
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
if (inner := root_model_inner_type(field.type_)) is not None:
return field_annotation_is_sequence(inner)
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
def is_scalar_sequence_field(field: ModelField) -> bool:
if (inner := root_model_inner_type(field.type_)) is not None:
return field_annotation_is_scalar_sequence(inner)
return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool:
@ -530,6 +544,14 @@ else:
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]
def root_model_inner_type(annotation: Any) -> Any:
if (
lenient_issubclass(annotation, BaseModel)
and "__root__" in annotation.__fields__
):
return annotation.__fields__["__root__"].annotation
return None
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
@ -549,6 +571,9 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_sequence(inner)
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
@ -572,7 +597,25 @@ def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
)
def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_root_model(arg) for arg in get_args(annotation))
return root_model_inner_type(annotation) is not None
def field_annotation_has_submodel_fields(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_has_submodel_fields(inner)
return lenient_issubclass(annotation, BaseModel)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_complex(inner)
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
@ -586,11 +629,17 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_scalar(annotation: Any) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_scalar(inner)
# handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
if (inner := root_model_inner_type(annotation)) is not None:
return field_annotation_is_scalar_sequence(inner)
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False

19
fastapi/dependencies/utils.py

@ -30,6 +30,8 @@ from fastapi._compat import (
copy_field_info,
create_body_model,
evaluate_forwardref,
field_annotation_has_submodel_fields,
field_annotation_is_root_model,
field_annotation_is_scalar,
get_annotation_from_field_info,
get_cached_model_fields,
@ -213,7 +215,7 @@ def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
if not fields:
return fields
first_field = fields[0]
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
if len(fields) == 1 and field_annotation_has_submodel_fields(first_field.type_):
fields_to_extract = get_cached_model_fields(first_field.type_)
return fields_to_extract
return fields
@ -454,7 +456,11 @@ def analyze_param(
type_annotation
) or is_uploadfile_sequence_annotation(type_annotation):
field_info = params.File(annotation=use_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation):
elif (
not field_annotation_is_scalar(type_annotation)
# Root models by default regarded as bodies
or field_annotation_is_root_model(type_annotation)
):
field_info = params.Body(annotation=use_annotation, default=default_value)
else:
field_info = params.Query(annotation=use_annotation, default=default_value)
@ -752,8 +758,9 @@ def request_params_to_args(
single_not_embedded_field = False
default_convert_underscores = True
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
fields_to_extract = get_cached_model_fields(first_field.type_)
single_not_embedded_field = True
if field_annotation_has_submodel_fields(first_field.type_):
fields_to_extract = get_cached_model_fields(first_field.type_)
single_not_embedded_field = True
# If headers are in a Pydantic model, the way to disable convert_underscores
# would be with Header(convert_underscores=False) at the Pydantic model level
default_convert_underscores = getattr(
@ -916,7 +923,9 @@ async def request_body_to_args(
fields_to_extract: List[ModelField] = body_fields
if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel):
if single_not_embedded_field and field_annotation_has_submodel_fields(
first_field.type_
):
fields_to_extract = get_cached_model_fields(first_field.type_)
if isinstance(received_body, FormData):

14
fastapi/encoders.py

@ -220,7 +220,7 @@ def jsonable_encoder(
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders.update(custom_encoder)
obj_dict = _model_dump(
serialized = _model_dump(
obj,
mode="json",
include=include,
@ -230,10 +230,16 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
if (
not PYDANTIC_V2
and isinstance(serialized, dict)
and "__root__" in serialized
):
serialized = serialized["__root__"]
return jsonable_encoder(
obj_dict,
serialized,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
# TODO: remove when deprecating Pydantic v1

519
tests/test_root_model.py

@ -0,0 +1,519 @@
from datetime import date
from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union
import pytest
from dirty_equals import IsDict
from fastapi import Body, FastAPI, Header, Path, Query
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from .utils import needs_pydanticv2
app = FastAPI()
if PYDANTIC_V2:
from pydantic import RootModel
else:
from pydantic import BaseModel
T = TypeVar("T")
class RootModel(Generic[T]):
def __class_getitem__(cls, type_):
class Inner(BaseModel):
__root__: type_
return Inner
Basic = RootModel[int]
class DictWrap(RootModel[Dict[str, int]]):
pass
if PYDANTIC_V2:
from pydantic import ConfigDict, Field, field_validator, model_serializer
class FieldWrap(RootModel[str]):
model_config = ConfigDict(
json_schema_extra={"description": "parameter starts with bar_"}
)
root: str = Field(pattern=r"^bar_.*$")
class CustomParsed(RootModel[str]):
model_config = ConfigDict(
json_schema_extra={"description": "parameter starts with foo_"}
)
@field_validator("root")
@classmethod
def validator(cls, v: str) -> str:
if not v.startswith("foo_"):
raise ValueError("must start with foo_")
return v
@model_serializer
def serialize(self):
return (
self.root
if self.root.endswith("_serialized")
else f"{self.root}_serialized"
)
def parse(self):
return self.root[len("foo_") :] # removeprefix
else:
from pydantic import validator
class FieldWrap(BaseModel):
class Config:
schema_extra = {"description": "parameter starts with bar_"}
__root__: str = Field(regex=r"^bar_.*$")
class CustomParsed(BaseModel):
class Config:
schema_extra = {"description": "parameter starts with foo_"}
__root__: str
@validator("__root__")
@classmethod
def validator(cls, v: str) -> str:
if not v.startswith("foo_"):
raise ValueError("must start with foo_")
return v
def dict(self, **kwargs: Any) -> Dict[str, Any]:
return {
"__root__": self.__root__
if self.__root__.endswith("_serialized")
else f"{self.__root__}_serialized"
}
def parse(self):
return self.__root__[len("foo_") :] # removeprefix
@app.get("/query/basic")
def query_basic(q: Annotated[Basic, Query()]):
return {"q": q}
@app.get("/query/fieldwrap")
def query_fieldwrap(q: Annotated[FieldWrap, Query()]):
return {"q": q}
@app.get("/query/customparsed")
def query_customparsed(q: Annotated[CustomParsed, Query()]):
return {"q": q.parse()}
@app.get("/path/basic/{p}")
def path_basic(p: Annotated[Basic, Path()]):
return {"p": p}
@app.get("/path/fieldwrap/{p}")
def path_fieldwrap(p: Annotated[FieldWrap, Path()]):
return {"p": p}
@app.get("/path/customparsed/{p}")
def path_customparsed(p: Annotated[CustomParsed, Path()]):
return {"p": p.parse()}
@app.post("/body/basic")
def body_basic(b: Annotated[Basic, Body()]):
return {"b": b}
@app.post("/body/fieldwrap")
def body_fieldwrap(b: Annotated[FieldWrap, Body()]):
return {"b": b}
@app.post("/body/customparsed")
def body_customparsed(b: Annotated[CustomParsed, Body()]):
return {"b": b.parse()}
@app.post("/body_default/basic")
def body_default_basic(b: Basic):
return {"b": b}
@app.post("/body_default/fieldwrap")
def body_default_fieldwrap(b: FieldWrap):
return {"b": b}
@app.post("/body_default/customparsed")
def body_default_customparsed(b: CustomParsed):
return {"b": b.parse()}
@app.get("/echo/basic")
def echo_basic(q: Annotated[Basic, Query()]) -> Basic:
return q
@app.get("/echo/fieldwrap")
def echo_fieldwrap(q: Annotated[FieldWrap, Query()]) -> FieldWrap:
return q
@app.get("/echo/customparsed")
def echo_customparsed(q: Annotated[CustomParsed, Query()]) -> CustomParsed:
return q
@app.post("/echo/dictwrap")
def echo_dictwrap(b: DictWrap) -> DictWrap:
return b.model_dump() if PYDANTIC_V2 else b.dict()["__root__"]
client = TestClient(app)
@pytest.mark.parametrize(
"url, response_json, request_body",
[
("/query/basic?q=42", {"q": 42}, None),
("/query/fieldwrap?q=bar_baz", {"q": "bar_baz"}, None),
("/query/customparsed?q=foo_bar", {"q": "bar"}, None),
("/path/basic/42", {"p": 42}, None),
("/path/fieldwrap/bar_baz", {"p": "bar_baz"}, None),
("/path/customparsed/foo_bar", {"p": "bar"}, None),
("/body/basic", {"b": 42}, "42"),
("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
("/body/customparsed", {"b": "bar"}, "foo_bar"),
("/body_default/basic", {"b": 42}, "42"),
("/body_default/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
("/body_default/customparsed", {"b": "bar"}, "foo_bar"),
("/echo/basic?q=42", 42, None),
("/echo/fieldwrap?q=bar_baz", "bar_baz", None),
("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None),
("/echo/dictwrap", {"test": 1}, {"test": 1}),
],
)
def test_root_model_200(url: str, response_json: Any, request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 200, response.text
assert response.json() == response_json
@pytest.mark.parametrize("type_", [int, str, bytes, Optional[int]])
@pytest.mark.parametrize("outer", [None, List, Sequence])
def test2_root_model_200__basic(type_: Type, outer: Optional[Type]):
inner = outer[type_] if outer else type_
Model = RootModel[inner]
app2 = FastAPI()
if outer is None:
@app2.get("/path/{p}")
def path_basic2(p: Annotated[Model, Path()]) -> Model:
return p
else:
with pytest.raises(
AssertionError, match="Path params must be of one of the supported types"
):
@app2.get("/path/{p}")
def path_basic2(p: Annotated[Model, Path()]) -> Model:
return p # pragma: nocover
@app2.get("/query")
def query_basic2(q: Annotated[Model, Query()]) -> Model:
return q
@app2.get("/header")
def path_basic2(h: Annotated[Model, Header()]) -> Model:
return h
@app2.post("/body")
def body_basic2(b: Annotated[Model, Body()]) -> Model:
return b
client2 = TestClient(app2)
if outer is None:
expected = "42" if type_ in (str, bytes) else 42
else:
expected = ["42", "43"] if type_ in (str, bytes) else [42, 43]
if outer is None:
assert client2.get("/path/42").json() == expected
assert client2.get("/query?q=42").json() == expected
assert client2.get("/header", headers={"h": "42"}).json() == expected
else:
assert client2.get("/path/42").json()["detail"] == "Not Found"
assert client2.get("/query?q=42&q=43").json() == expected
assert (
client2.get("/header", headers=[("h", "42"), ("h", "43")]).json()
== expected
)
assert client2.post("/body", json=expected).json() == expected
@pytest.mark.parametrize("left", [RootModel[date], date])
@pytest.mark.parametrize("right", [RootModel[int], int])
@needs_pydanticv2
def test_root_model_union(left: Any, right: Any):
Model = Union[left, right]
app2 = FastAPI()
@app2.get("/query")
def handler1(q: Annotated[Model, Query()]):
return {"q": q}
@app2.post("/body")
def handler2(b: Annotated[Model, Body()]):
return {"b": b}
client2 = TestClient(app2)
for val in ["2025-01-02", 42]:
response = client2.get(f"/query?q={val}")
assert response.status_code == 200, response.text
assert response.json() == {"q": val}
response = client2.post("/body", json=val)
assert response.status_code == 200, response.text
assert response.json() == {"b": val}
response = client2.get("/query?q=bad")
assert response.status_code == 422, response.text
assert {detail["msg"] for detail in response.json()["detail"]} == {
"Input should be a valid date or datetime, input is too short",
"Input should be a valid integer, unable to parse string as an integer",
}
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/basic?q=my_bad_not_int", ["query", "q"], None),
("/path/basic/my_bad_not_int", ["path", "p"], None),
("/body/basic", ["body"], "my_bad_not_int"),
("/body_default/basic", ["body"], "my_bad_not_int"),
],
)
def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": error_path,
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "my_bad_not_int",
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
]
}
)
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None),
("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None),
("/body/fieldwrap", ["body"], "my_bad_prefix_val"),
("/body_default/fieldwrap", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "string_pattern_mismatch",
"loc": error_path,
"msg": "String should match pattern '^bar_.*$'",
"input": "my_bad_prefix_val",
"ctx": {"pattern": "^bar_.*$"},
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": 'string does not match regex "^bar_.*$"',
"type": "value_error.str.regex",
"ctx": {"pattern": "^bar_.*$"},
}
]
}
)
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None),
("/path/customparsed/my_bad_prefix_val", ["path", "p"], None),
("/body/customparsed", ["body"], "my_bad_prefix_val"),
("/body_default/customparsed", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_customparsed_422(
url: str, error_path: List[str], request_body: Any
):
response = client.post(url, json=request_body) if request_body else client.get(url)
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "value_error",
"loc": error_path,
"msg": "Value error, must start with foo_",
"input": "my_bad_prefix_val",
"ctx": {"error": {}},
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": [*error_path, "__root__"],
"msg": "must start with foo_",
"type": "value_error",
}
]
}
)
def test_root_model_dictwrap_422():
response = client.post("/echo/dictwrap", json={"test": "fail_not_int"})
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": ["body", "test"],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "fail_not_int",
}
]
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": ["body", "__root__", "test"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
]
}
)
@pytest.mark.parametrize(
"model, path_name, expected_model_schema",
[
(Basic, "basic", {"type": "integer"}),
(
FieldWrap,
"fieldwrap",
{
"type": "string",
"pattern": "^bar_.*$",
"description": "parameter starts with bar_",
},
),
(
CustomParsed,
"customparsed",
{
"type": "string",
"description": "parameter starts with foo_",
},
),
],
)
def test_openapi_schema(
model: Type, path_name: str, expected_model_schema: Dict[str, Any]
):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
paths = response.json()["paths"]
ref_name = model.__name__.replace("[", "_").replace("]", "_")
schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}}
assert paths[f"/query/{path_name}"]["get"]["parameters"] == [
{"in": "query", "name": "q", "required": True, **schema_ref}
]
assert paths[f"/path/{path_name}/{{p}}"]["get"]["parameters"] == [
{"in": "path", "name": "p", "required": True, **schema_ref}
]
assert paths[f"/body/{path_name}"]["post"]["requestBody"] == {
"content": {"application/json": schema_ref},
"required": True,
}
assert paths[f"/body_default/{path_name}"]["post"]["requestBody"] == {
"content": {"application/json": schema_ref},
"required": True,
}
assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == {
"content": {"application/json": schema_ref},
"description": "Successful Response",
}
model_schema = response.json()["components"]["schemas"][ref_name]
model_schema.pop("title")
assert model_schema == expected_model_schema
def test_openapi_schema_dictwrap():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
operation = response.json()["paths"]["/echo/dictwrap"]["post"]
ref = {"schema": {"$ref": "#/components/schemas/DictWrap"}}
assert operation["requestBody"] == {
"content": {"application/json": ref},
"required": True,
}
assert operation["responses"]["200"] == {
"content": {"application/json": ref},
"description": "Successful Response",
}
assert response.json()["components"]["schemas"]["DictWrap"] == {
"title": "DictWrap",
"type": "object",
"additionalProperties": {"type": "integer"},
}
Loading…
Cancel
Save