From 320084c60976784481150584e2a97a0b9212884f Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Mon, 18 Mar 2024 09:03:11 +0200 Subject: [PATCH] Root model by default again as body --- fastapi/_compat.py | 20 +++++------ fastapi/dependencies/utils.py | 7 +++- tests/test_root_model.py | 67 ++++++++++++++++++++++++++++++++++- 3 files changed, 82 insertions(+), 12 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 8f25244e9..bf9e46e68 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -239,11 +239,7 @@ if PYDANTIC_V2: return field_mapping, definitions # type: ignore[return-value] def is_scalar_field(field: ModelField) -> bool: - from fastapi import params - - return field_annotation_is_scalar( - field.field_info.annotation - ) and not isinstance(field.field_info, params.Body) + return field_annotation_is_scalar(field.field_info.annotation) def is_sequence_field(field: ModelField) -> bool: return field_annotation_is_sequence(field.field_info.annotation) @@ -505,11 +501,7 @@ else: ) def is_scalar_field(field: ModelField) -> bool: - from fastapi import params - - return is_pv1_scalar_field(field) and not isinstance( - field.field_info, params.Body - ) + return is_pv1_scalar_field(field) def is_sequence_field(field: ModelField) -> bool: return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined] @@ -595,6 +587,14 @@ def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: ) +def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + return any(field_annotation_is_root_model(arg) for arg in get_args(annotation)) + + return root_model_inner_type(annotation) is not None + + def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: origin = get_origin(annotation) if origin is Union or origin is UnionType: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..359004601 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -30,6 +30,7 @@ from fastapi._compat import ( copy_field_info, create_body_model, evaluate_forwardref, + field_annotation_is_root_model, field_annotation_is_scalar, get_annotation_from_field_info, get_cached_model_fields, @@ -454,7 +455,11 @@ def analyze_param( type_annotation ) or is_uploadfile_sequence_annotation(type_annotation): field_info = params.File(annotation=use_annotation, default=default_value) - elif not field_annotation_is_scalar(annotation=type_annotation): + elif ( + not field_annotation_is_scalar(type_annotation) + # Root models by default regarded as bodies + or field_annotation_is_root_model(type_annotation) + ): field_info = params.Body(annotation=use_annotation, default=default_value) else: field_info = params.Query(annotation=use_annotation, default=default_value) diff --git a/tests/test_root_model.py b/tests/test_root_model.py index b3d2acc6d..e2ba6e028 100644 --- a/tests/test_root_model.py +++ b/tests/test_root_model.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Type, Union import pytest from dirty_equals import IsDict @@ -131,6 +131,21 @@ def body_customparsed(b: CustomParsed = Body()): return {"b": b.parse()} +@app.post("/body_default/basic") +def body_default_basic(b: Basic): + return {"b": b} + + +@app.post("/body_default/fieldwrap") +def body_default_fieldwrap(b: FieldWrap): + return {"b": b} + + +@app.post("/body_default/customparsed") +def body_default_customparsed(b: CustomParsed): + return {"b": b.parse()} + + @app.get("/echo/basic") def echo_basic(q: Basic = Query()) -> Basic: return q @@ -166,6 +181,9 @@ client = TestClient(app) ("/body/basic", {"b": 42}, "42"), ("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"), ("/body/customparsed", {"b": "bar"}, "foo_bar"), + ("/body_default/basic", {"b": 42}, "42"), + ("/body_default/fieldwrap", {"b": "bar_baz"}, "bar_baz"), + ("/body_default/customparsed", {"b": "bar"}, "foo_bar"), ("/echo/basic?q=42", 42, None), ("/echo/fieldwrap?q=bar_baz", "bar_baz", None), ("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None), @@ -178,12 +196,53 @@ def test_root_model_200(url: str, response_json: Any, request_body: Any): assert response.json() == response_json +def test_root_model_union(): + if PYDANTIC_V2: + from pydantic import RootModel + + RootModelInt = RootModel[int] + RootModelStr = RootModel[str] + else: + + class RootModelInt(BaseModel): + __root__: int + + class RootModelStr(BaseModel): + __root__: str + + app2 = FastAPI() + + @app2.post("/union") + def union_handler(b: Union[RootModelInt, RootModelStr]): + return {"b": b} + + client2 = TestClient(app2) + for body in [42, "foo"]: + response = client2.post("/union", json=body) + assert response.status_code == 200, response.text + assert response.json() == {"b": body} + + response = client2.post("/union", json=["bad_list"]) + assert response.status_code == 422, response.text + if PYDANTIC_V2: + assert {detail["msg"] for detail in response.json()["detail"]} == { + "Input should be a valid integer", + "Input should be a valid string", + } + else: + assert {detail["msg"] for detail in response.json()["detail"]} == { + "value is not a valid integer", + "str type expected", + } + + @pytest.mark.parametrize( "url, error_path, request_body", [ ("/query/basic?q=my_bad_not_int", ["query", "q"], None), ("/path/basic/my_bad_not_int", ["path", "p"], None), ("/body/basic", ["body"], "my_bad_not_int"), + ("/body_default/basic", ["body"], "my_bad_not_int"), ], ) def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any): @@ -221,6 +280,7 @@ def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any ("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None), ("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None), ("/body/fieldwrap", ["body"], "my_bad_prefix_val"), + ("/body_default/fieldwrap", ["body"], "my_bad_prefix_val"), ], ) def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any): @@ -260,6 +320,7 @@ def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: ("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None), ("/path/customparsed/my_bad_prefix_val", ["path", "p"], None), ("/body/customparsed", ["body"], "my_bad_prefix_val"), + ("/body_default/customparsed", ["body"], "my_bad_prefix_val"), ], ) def test_root_model_customparsed_422( @@ -362,6 +423,10 @@ def test_openapi_schema(model: Type, model_schema: Dict[str, Any]): "content": {"application/json": ref}, "required": True, } + assert paths[f"/body_default/{model.__name__.lower()}"]["post"]["requestBody"] == { + "content": {"application/json": ref}, + "required": True, + } assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == { "content": {"application/json": ref}, "description": "Successful Response",