Browse Source

Root model by default again as body

pull/11306/head
Markus Sintonen 1 year ago
parent
commit
320084c609
  1. 20
      fastapi/_compat.py
  2. 7
      fastapi/dependencies/utils.py
  3. 67
      tests/test_root_model.py

20
fastapi/_compat.py

@ -239,11 +239,7 @@ if PYDANTIC_V2:
return field_mapping, definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
return field_annotation_is_scalar(field.field_info.annotation)
def is_sequence_field(field: ModelField) -> bool:
return field_annotation_is_sequence(field.field_info.annotation)
@ -505,11 +501,7 @@ else:
)
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return is_pv1_scalar_field(field) and not isinstance(
field.field_info, params.Body
)
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
@ -595,6 +587,14 @@ def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
)
def field_annotation_is_root_model(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_root_model(arg) for arg in get_args(annotation))
return root_model_inner_type(annotation) is not None
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:

7
fastapi/dependencies/utils.py

@ -30,6 +30,7 @@ from fastapi._compat import (
copy_field_info,
create_body_model,
evaluate_forwardref,
field_annotation_is_root_model,
field_annotation_is_scalar,
get_annotation_from_field_info,
get_cached_model_fields,
@ -454,7 +455,11 @@ def analyze_param(
type_annotation
) or is_uploadfile_sequence_annotation(type_annotation):
field_info = params.File(annotation=use_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation):
elif (
not field_annotation_is_scalar(type_annotation)
# Root models by default regarded as bodies
or field_annotation_is_root_model(type_annotation)
):
field_info = params.Body(annotation=use_annotation, default=default_value)
else:
field_info = params.Query(annotation=use_annotation, default=default_value)

67
tests/test_root_model.py

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Type
from typing import Any, Dict, List, Type, Union
import pytest
from dirty_equals import IsDict
@ -131,6 +131,21 @@ def body_customparsed(b: CustomParsed = Body()):
return {"b": b.parse()}
@app.post("/body_default/basic")
def body_default_basic(b: Basic):
return {"b": b}
@app.post("/body_default/fieldwrap")
def body_default_fieldwrap(b: FieldWrap):
return {"b": b}
@app.post("/body_default/customparsed")
def body_default_customparsed(b: CustomParsed):
return {"b": b.parse()}
@app.get("/echo/basic")
def echo_basic(q: Basic = Query()) -> Basic:
return q
@ -166,6 +181,9 @@ client = TestClient(app)
("/body/basic", {"b": 42}, "42"),
("/body/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
("/body/customparsed", {"b": "bar"}, "foo_bar"),
("/body_default/basic", {"b": 42}, "42"),
("/body_default/fieldwrap", {"b": "bar_baz"}, "bar_baz"),
("/body_default/customparsed", {"b": "bar"}, "foo_bar"),
("/echo/basic?q=42", 42, None),
("/echo/fieldwrap?q=bar_baz", "bar_baz", None),
("/echo/customparsed?q=foo_bar", "foo_bar_serialized", None),
@ -178,12 +196,53 @@ def test_root_model_200(url: str, response_json: Any, request_body: Any):
assert response.json() == response_json
def test_root_model_union():
if PYDANTIC_V2:
from pydantic import RootModel
RootModelInt = RootModel[int]
RootModelStr = RootModel[str]
else:
class RootModelInt(BaseModel):
__root__: int
class RootModelStr(BaseModel):
__root__: str
app2 = FastAPI()
@app2.post("/union")
def union_handler(b: Union[RootModelInt, RootModelStr]):
return {"b": b}
client2 = TestClient(app2)
for body in [42, "foo"]:
response = client2.post("/union", json=body)
assert response.status_code == 200, response.text
assert response.json() == {"b": body}
response = client2.post("/union", json=["bad_list"])
assert response.status_code == 422, response.text
if PYDANTIC_V2:
assert {detail["msg"] for detail in response.json()["detail"]} == {
"Input should be a valid integer",
"Input should be a valid string",
}
else:
assert {detail["msg"] for detail in response.json()["detail"]} == {
"value is not a valid integer",
"str type expected",
}
@pytest.mark.parametrize(
"url, error_path, request_body",
[
("/query/basic?q=my_bad_not_int", ["query", "q"], None),
("/path/basic/my_bad_not_int", ["path", "p"], None),
("/body/basic", ["body"], "my_bad_not_int"),
("/body_default/basic", ["body"], "my_bad_not_int"),
],
)
def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any):
@ -221,6 +280,7 @@ def test_root_model_basic_422(url: str, error_path: List[str], request_body: Any
("/query/fieldwrap?q=my_bad_prefix_val", ["query", "q"], None),
("/path/fieldwrap/my_bad_prefix_val", ["path", "p"], None),
("/body/fieldwrap", ["body"], "my_bad_prefix_val"),
("/body_default/fieldwrap", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body: Any):
@ -260,6 +320,7 @@ def test_root_model_fieldwrap_422(url: str, error_path: List[str], request_body:
("/query/customparsed?q=my_bad_prefix_val", ["query", "q"], None),
("/path/customparsed/my_bad_prefix_val", ["path", "p"], None),
("/body/customparsed", ["body"], "my_bad_prefix_val"),
("/body_default/customparsed", ["body"], "my_bad_prefix_val"),
],
)
def test_root_model_customparsed_422(
@ -362,6 +423,10 @@ def test_openapi_schema(model: Type, model_schema: Dict[str, Any]):
"content": {"application/json": ref},
"required": True,
}
assert paths[f"/body_default/{model.__name__.lower()}"]["post"]["requestBody"] == {
"content": {"application/json": ref},
"required": True,
}
assert paths[f"/echo/{model.__name__.lower()}"]["get"]["responses"]["200"] == {
"content": {"application/json": ref},
"description": "Successful Response",

Loading…
Cancel
Save