Browse Source

Add PYDANTIC_V2 check to __root__ check. Use type alias in basic test

pull/11306/head
Markus Sintonen 1 year ago
parent
commit
65f1ba05e5
  1. 8
      fastapi/encoders.py
  2. 52
      tests/test_root_model.py

8
fastapi/encoders.py

@ -230,10 +230,14 @@ def jsonable_encoder(
exclude_none=exclude_none, exclude_none=exclude_none,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
) )
if ( if (
isinstance(serialized, dict) and "__root__" in serialized not PYDANTIC_V2
): # TODO: remove when deprecating Pydantic v1 and isinstance(serialized, dict)
and "__root__" in serialized
):
serialized = serialized["__root__"] serialized = serialized["__root__"]
return jsonable_encoder( return jsonable_encoder(
serialized, serialized,
exclude_none=exclude_none, exclude_none=exclude_none,

52
tests/test_root_model.py

@ -5,7 +5,6 @@ from dirty_equals import IsDict
from fastapi import Body, FastAPI, Path, Query from fastapi import Body, FastAPI, Path, Query
from fastapi._compat import PYDANTIC_V2 from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from fastapi.utils import match_pydantic_error_url
from pydantic import BaseModel from pydantic import BaseModel
app = FastAPI() app = FastAPI()
@ -13,8 +12,7 @@ app = FastAPI()
if PYDANTIC_V2: if PYDANTIC_V2:
from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer
class Basic(RootModel[int]): Basic = RootModel[int]
pass
class FieldWrap(RootModel[str]): class FieldWrap(RootModel[str]):
model_config = ConfigDict( model_config = ConfigDict(
@ -256,7 +254,6 @@ def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any
"loc": error_path, "loc": error_path,
"msg": "Input should be a valid integer, unable to parse string as an integer", "msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "my_bad_not_int", "input": "my_bad_not_int",
"url": match_pydantic_error_url("int_parsing"),
} }
] ]
} }
@ -294,7 +291,6 @@ def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body:
"loc": error_path, "loc": error_path,
"msg": "String should match pattern '^bar_.*$'", "msg": "String should match pattern '^bar_.*$'",
"input": "my_bad_prefix_val", "input": "my_bad_prefix_val",
"url": match_pydantic_error_url("string_pattern_mismatch"),
"ctx": {"pattern": "^bar_.*$"}, "ctx": {"pattern": "^bar_.*$"},
} }
] ]
@ -337,7 +333,6 @@ def test_root_model_customparsed_422(
"msg": "Value error, must start with foo_", "msg": "Value error, must start with foo_",
"input": "my_bad_prefix_val", "input": "my_bad_prefix_val",
"ctx": {"error": {}}, "ctx": {"error": {}},
"url": match_pydantic_error_url("value_error"),
} }
] ]
} }
@ -366,7 +361,6 @@ def test_root_model_dictwrap_422():
"loc": ["body", "test"], "loc": ["body", "test"],
"msg": "Input should be a valid integer, unable to parse string as an integer", "msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "fail_not_int", "input": "fail_not_int",
"url": match_pydantic_error_url("int_parsing"),
} }
] ]
} }
@ -385,13 +379,13 @@ def test_root_model_dictwrap_422():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, model_schema", "model, path_name, expected_model_schema",
[ [
(Basic, {"title": "Basic", "type": "integer"}), (Basic, "basic", {"type": "integer"}),
( (
FieldWrap, FieldWrap,
"fieldwrap",
{ {
"title": "FieldWrap",
"type": "string", "type": "string",
"pattern": "^bar_.*$", "pattern": "^bar_.*$",
"description": "parameter starts with bar_", "description": "parameter starts with bar_",
@ -399,42 +393,46 @@ def test_root_model_dictwrap_422():
), ),
( (
CustomParsed, CustomParsed,
"customparsed",
{ {
"title": "CustomParsed",
"type": "string", "type": "string",
"description": "parameter starts with foo_", "description": "parameter starts with foo_",
}, },
), ),
], ],
) )
def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): def test_openapi_schema(
model: Type, path_name: str, expected_model_schema: Dict[str, Any]
):
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
paths = response.json()["paths"] paths = response.json()["paths"]
ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}} ref_name = model.__name__.replace("[", "_").replace("]", "_")
assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [ schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}}
{"in": "query", "name": "q", "required": True, **ref}
assert paths[f"/query/{path_name}"]["get"]["parameters"] == [
{"in": "query", "name": "q", "required": True, **schema_ref}
] ]
assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [ assert paths[f"/path/{path_name}/{{p}}"]["get"]["parameters"] == [
{"in": "path", "name": "p", "required": True, **ref} {"in": "path", "name": "p", "required": True, **schema_ref}
] ]
assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == { assert paths[f"/body/{path_name}"]["post"]["requestBody"] == {
"content": {"application/json": ref}, "content": {"application/json": schema_ref},
"required": True, "required": True,
} }
assert paths[f"/body_default/{model.__name__.lower()}"]["post"]["requestBody"] == { assert paths[f"/body_default/{path_name}"]["post"]["requestBody"] == {
"content": {"application/json": ref}, "content": {"application/json": schema_ref},
"required": True, "required": True,
} }
assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == {
"content": {"application/json": ref}, "content": {"application/json": schema_ref},
"description": "Successful Response", "description": "Successful Response",
} }
assert response.json()["components"]["schemas"][model.__name__] == {
"title": model.__name__, model_schema = response.json()["components"]["schemas"][ref_name]
**model_schema, model_schema.pop("title")
} assert model_schema == expected_model_schema
def test_openapi_schema_dictwrap(): def test_openapi_schema_dictwrap():

Loading…
Cancel
Save