From 65f1ba05e56bf9068bcebbc9de2cc25d57d9aef0 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Wed, 24 Apr 2024 20:58:14 +0300 Subject: [PATCH] Add PYDANTIC_V2 check to __root__ check. Use type alias in basic test --- fastapi/encoders.py | 8 +++++-- tests/test_root_model.py | 52 +++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/fastapi/encoders.py b/fastapi/encoders.py index aa5e04fc6..bc761a0d5 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -230,10 +230,14 @@ def jsonable_encoder( exclude_none=exclude_none, exclude_defaults=exclude_defaults, ) + if ( - isinstance(serialized, dict) and "__root__" in serialized - ): # TODO: remove when deprecating Pydantic v1 + not PYDANTIC_V2 + and isinstance(serialized, dict) + and "__root__" in serialized + ): serialized = serialized["__root__"] + return jsonable_encoder( serialized, exclude_none=exclude_none, diff --git a/tests/test_root_model.py b/tests/test_root_model.py index e2ba6e028..5c91bc10a 100644 --- a/tests/test_root_model.py +++ b/tests/test_root_model.py @@ -5,7 +5,6 @@ from dirty_equals import IsDict from fastapi import Body, FastAPI, Path, Query from fastapi._compat import PYDANTIC_V2 from fastapi.testclient import TestClient -from fastapi.utils import match_pydantic_error_url from pydantic import BaseModel app = FastAPI() @@ -13,8 +12,7 @@ app = FastAPI() if PYDANTIC_V2: from pydantic import ConfigDict, Field, RootModel, field_validator, model_serializer - class Basic(RootModel[int]): - pass + Basic = RootModel[int] class FieldWrap(RootModel[str]): model_config = ConfigDict( @@ -256,7 +254,6 @@ def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any "loc": error_path, "msg": "Input should be a valid integer, unable to parse string as an integer", "input": "my_bad_not_int", - "url": match_pydantic_error_url("int_parsing"), } ] } @@ -294,7 +291,6 @@ def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: "loc": error_path, "msg": "String should match pattern '^bar_.*$'", "input": "my_bad_prefix_val", - "url": match_pydantic_error_url("string_pattern_mismatch"), "ctx": {"pattern": "^bar_.*$"}, } ] @@ -337,7 +333,6 @@ def test_root_model_customparsed_422( "msg": "Value error, must start with foo_", "input": "my_bad_prefix_val", "ctx": {"error": {}}, - "url": match_pydantic_error_url("value_error"), } ] } @@ -366,7 +361,6 @@ def test_root_model_dictwrap_422(): "loc": ["body", "test"], "msg": "Input should be a valid integer, unable to parse string as an integer", "input": "fail_not_int", - "url": match_pydantic_error_url("int_parsing"), } ] } @@ -385,13 +379,13 @@ def test_root_model_dictwrap_422(): @pytest.mark.parametrize( - "model, model_schema", + "model, path_name, expected_model_schema", [ - (Basic, {"title": "Basic", "type": "integer"}), + (Basic, "basic", {"type": "integer"}), ( FieldWrap, + "fieldwrap", { - "title": "FieldWrap", "type": "string", "pattern": "^bar_.*$", "description": "parameter starts with bar_", @@ -399,42 +393,46 @@ def test_root_model_dictwrap_422(): ), ( CustomParsed, + "customparsed", { - "title": "CustomParsed", "type": "string", "description": "parameter starts with foo_", }, ), ], ) -def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): +def test_openapi_schema( + model: Type, path_name: str, expected_model_schema: Dict[str, Any] +): response = client.get("/openapi.json") assert response.status_code == 200, response.text paths = response.json()["paths"] - ref = {"schema": {"$ref": f"#/components/schemas/{model.__name__}"}} - assert paths[f"/query/{model.__name__.lower()}"]["get"]["parameters"] == [ - {"in": "query", "name": "q", "required": True, **ref} + ref_name = model.__name__.replace("[", "_").replace("]", "_") + schema_ref = {"schema": {"$ref": f"#/components/schemas/{ref_name}"}} + + assert paths[f"/query/{path_name}"]["get"]["parameters"] == [ + {"in": "query", "name": "q", "required": True, **schema_ref} ] - assert paths[f"/path/{model.__name__.lower()}/{{p}}"]["get"]["parameters"] == [ - {"in": "path", "name": "p", "required": True, **ref} + assert paths[f"/path/{path_name}/{{p}}"]["get"]["parameters"] == [ + {"in": "path", "name": "p", "required": True, **schema_ref} ] - assert paths[f"/body/{model.__name__.lower()}"]["post"]["requestBody"] == { - "content": {"application/json": ref}, + assert paths[f"/body/{path_name}"]["post"]["requestBody"] == { + "content": {"application/json": schema_ref}, "required": True, } - assert paths[f"/body_default/{model.__name__.lower()}"]["post"]["requestBody"] == { - "content": {"application/json": ref}, + assert paths[f"/body_default/{path_name}"]["post"]["requestBody"] == { + "content": {"application/json": schema_ref}, "required": True, } - assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { - "content": {"application/json": ref}, + assert paths[f"/echo/{path_name}"]["get"]["responses"]["200"] == { + "content": {"application/json": schema_ref}, "description": "Successful Response", } - assert response.json()["components"]["schemas"][model.__name__] == { - "title": model.__name__, - **model_schema, - } + + model_schema = response.json()["components"]["schemas"][ref_name] + model_schema.pop("title") + assert model_schema == expected_model_schema def test_openapi_schema_dictwrap():