Browse Source

Add PYDANTIC_V2 check to __root__ check. Use type alias in basic test

pull/11306/head
Markus Sintonen 1 year ago
parent
commit
65f1ba05e5
  1. 8
      fastapi/encoders.py
  2. 52
      tests/test_root_model.py

8
fastapi/encoders.py

@ -230,10 +230,14 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if (
isinstance(serialized, dict) and "__root__" in serialized
): # TODO: remove when deprecating Pydantic v1
not PYDANTIC_V2
and isinstance(serialized, dict)
and "__root__" in serialized
):
serialized = serialized["__root__"]
return jsonable_encoder(
serialized,
exclude_none=exclude_none,

52
tests/test_root_model.py

@ -5,7 +5,6 @@ from dirty_equals import IsDict
from fastapi import Body, FastAPI, Path, Query
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
from fastapi.utils import match_pydantic_error_url
from pydantic import BaseModel
app = FastAPI()
@ -13,8 +12,7 @@ app = FastAPI()
if PYDANTIC_V2:
from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer
class Basic(RootModel[int]):
pass
Basic = RootModel[int]
class FieldWrap(RootModel[str]):
model_config = ConfigDict(
@ -256,7 +254,6 @@ def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any
"loc": error_path,
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "my_bad_not_int",
"url": match_pydantic_error_url("int_parsing"),
}
]
}
@ -294,7 +291,6 @@ def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body:
"loc": error_path,
"msg": "String should match pattern '^bar_.*$'",
"input": "my_bad_prefix_val",
"url": match_pydantic_error_url("string_pattern_mismatch"),
"ctx": {"pattern": "^bar_.*$"},
}
]
@ -337,7 +333,6 @@ def test_root_model_customparsed_422(
"msg": "Value error, must start with foo_",
"input": "my_bad_prefix_val",
"ctx": {"error": {}},
"url": match_pydantic_error_url("value_error"),
}
]
}
@ -366,7 +361,6 @@ def test_root_model_dictwrap_422():
"loc": ["body", "test"],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "fail_not_int",
"url": match_pydantic_error_url("int_parsing"),
}
]
}
@ -385,13 +379,13 @@ def test_root_model_dictwrap_422():
@pytest.mark.parametrize(
"model, model_schema",
"model, path_name, expected_model_schema",
[
(Basic, {"title": "Basic", "type": "integer"}),
(Basic, "basic", {"type": "integer"}),
(
FieldWrap,
"fieldwrap",
{
"title": "FieldWrap",
"type": "string",
"pattern": "^bar_.*$",
"description": "parameter starts with bar_",
@ -399,42 +393,46 @@ def test_root_model_dictwrap_422():
),
(
CustomParsed,
"customparsed",
{
"title": "CustomParsed",
"type": "string",
"description": "parameter starts with foo_",
},
),
],
)
def test_openapi_schema(model: Type, model_schema: Dict[str, Any]):
def test_openapi_schema(
model: Type, path_name: str, expected_model_schema: Dict[str, Any]
):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
paths = response.json()["paths"]
ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}}
assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [
{"in": "query", "name": "q", "required": True, **ref}
ref_name = model.__name__.replace("[", "_").replace("]", "_")
schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}}
assert paths[f"/query/{path_name}"]["get"]["parameters"] == [
{"in": "query", "name": "q", "required": True, **schema_ref}
]
assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [
{"in": "path", "name": "p", "required": True, **ref}
assert paths[f"/path/{path_name}/{{p}}"]["get"]["parameters"] == [
{"in": "path", "name": "p", "required": True, **schema_ref}
]
assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == {
"content": {"application/json": ref},
assert paths[f"/body/{path_name}"]["post"]["requestBody"] == {
"content": {"application/json": schema_ref},
"required": True,
}
assert paths[f"/body_default/{model.__name__.lower()}"]["post"]["requestBody"] == {
"content": {"application/json": ref},
assert paths[f"/body_default/{path_name}"]["post"]["requestBody"] == {
"content": {"application/json": schema_ref},
"required": True,
}
assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == {
"content": {"application/json": ref},
assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == {
"content": {"application/json": schema_ref},
"description": "Successful Response",
}
assert response.json()["components"]["schemas"][model.__name__] == {
"title": model.__name__,
**model_schema,
}
model_schema = response.json()["components"]["schemas"][ref_name]
model_schema.pop("title")
assert model_schema == expected_model_schema
def test_openapi_schema_dictwrap():

Loading…
Cancel
Save