Browse Source

Use TypeAdapter.validate_json for Pydantic v2

pull/13951/head
Martynov Maxim 3 days ago
parent
commit
3160247e21
No known key found for this signature in database GPG Key ID: 9C23E39F5BBC88CC
  1. 63
      fastapi/_compat.py
  2. 53
      fastapi/dependencies/utils.py
  3. 46
      fastapi/routing.py
  4. 15
      tests/test_tutorial/test_body/test_tutorial001.py
  5. 2
      tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py
  6. 28
      tests/test_tutorial/test_handling_errors/test_tutorial005.py

63
fastapi/_compat.py

@ -1,3 +1,4 @@
import json
from collections import deque
from copy import copy
from dataclasses import dataclass, is_dataclass
@ -135,6 +136,22 @@ if PYDANTIC_V2:
errors=exc.errors(include_url=False), loc_prefix=loc
)
def validate_json(
self,
value: Any,
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_json(value),
None,
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
def serialize(
self,
value: Any,
@ -292,6 +309,18 @@ if PYDANTIC_V2:
for name, field_info in model.model_fields.items()
]
def parse_json_field(
field: ModelField,
data: bytes,
values: Dict[str, Any],
loc: Tuple[str, ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
value, errors = field.validate_json(data)
if isinstance(errors, list):
new_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=loc)
return None, new_errors
return value, errors
else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
@ -530,6 +559,17 @@ else:
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]
def parse_json_field(
field: ModelField,
data: bytes,
values: Dict[str, Any],
loc: Tuple[str, ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
parsed_value, errors = parse_json(data, loc)
if errors:
return None, errors
return field.validate(parsed_value, values, loc=loc)
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
@ -662,3 +702,26 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return get_model_fields(model)
def parse_json(
data: bytes,
loc: Tuple[str, ...],
) -> Tuple[Any, List[Dict[str, Any]]]:
try:
return json.loads(data), []
except json.JSONDecodeError as e:
return None, [
{
"type": "value_error.jsondecode",
"loc": loc + (e.pos,),
"msg": f"Invalid JSON: {e.msg}",
"ctx": {
"msg": e.msg,
"doc": data,
"pos": e.pos,
"lineno": e.lineno,
"colno": e.colno,
},
}
]

53
fastapi/dependencies/utils.py

@ -42,15 +42,14 @@ from fastapi._compat import (
is_uploadfile_or_nonable_uploadfile_annotation,
is_uploadfile_sequence_annotation,
lenient_issubclass,
parse_json,
parse_json_field,
sequence_types,
serialize_sequence_value,
value_is_sequence,
)
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.concurrency import asynccontextmanager, contextmanager_in_threadpool
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
@ -573,7 +572,8 @@ async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
body: Optional[Union[bytes, Dict[str, Any], FormData]] = None,
is_body_json: bool = False,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
@ -616,6 +616,7 @@ async def solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
is_body_json=is_body_json,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
@ -666,6 +667,7 @@ async def solve_dependencies(
) = await request_body_to_args( # body_params checked above
body_fields=dependant.body_params,
received_body=body,
is_body_json=is_body_json,
embed_body_fields=embed_body_fields,
)
values.update(body_values)
@ -695,6 +697,24 @@ async def solve_dependencies(
)
def _validate_json_body_as_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
if value is None:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
v_, errors_ = parse_json_field(field, value, values=values, loc=loc)
if isinstance(errors_, ErrorWrapper):
return None, [errors_]
elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
return None, new_errors
else:
return v_, []
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
@ -904,7 +924,8 @@ async def _extract_form_body(
async def request_body_to_args(
body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
received_body: Optional[Union[bytes, Dict[str, Any], FormData]],
is_body_json: bool,
embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values: Dict[str, Any] = {}
@ -924,16 +945,28 @@ async def request_body_to_args(
if single_not_embedded_field:
loc: Tuple[str, ...] = ("body",)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
)
if is_body_json:
v_, errors_ = _validate_json_body_as_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
)
else:
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_
if is_body_json and isinstance(received_body, bytes):
body_to_process, errors = parse_json(received_body, loc=("body",))
if errors:
return values, errors
for field in body_fields:
loc = ("body", field.alias)
value: Optional[Any] = None
if body_to_process is not None:
try:
value = body_to_process.get(field.alias)
value = body_to_process.get(field.alias) # type: ignore[union-attr]
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))

46
fastapi/routing.py

@ -2,7 +2,6 @@ import asyncio
import dataclasses
import email.message
import inspect
import json
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum
from typing import (
@ -25,7 +24,6 @@ from typing import (
from fastapi import params
from fastapi._compat import (
ModelField,
Undefined,
_get_model_config,
_model_dump,
_normalize_errors,
@ -243,42 +241,23 @@ def get_request_handler(
async with AsyncExitStack() as file_stack:
try:
body: Any = None
is_body_json = False
if body_field:
if is_body_form:
body = await request.form()
file_stack.push_async_callback(body.close)
else:
body_bytes = await request.body()
if body_bytes:
json_body: Any = Undefined
content_type_value = request.headers.get("content-type")
if not content_type_value:
json_body = await request.json()
else:
message = email.message.Message()
message["content-type"] = content_type_value
if message.get_content_maintype() == "application":
subtype = message.get_content_subtype()
if subtype == "json" or subtype.endswith("+json"):
json_body = await request.json()
if json_body != Undefined:
body = json_body
else:
body = body_bytes
except json.JSONDecodeError as e:
validation_error = RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
],
body=e.doc,
)
raise validation_error from e
body = await request.body() or None
if body:
message = email.message.Message()
message["content-type"] = request.headers.get(
"content-type", "application/json"
)
if message.get_content_maintype() == "application":
subtype = message.get_content_subtype()
is_body_json = subtype == "json" or subtype.endswith(
"+json"
)
except HTTPException:
# If a middleware raises an HTTPException, it should be raised again
raise
@ -293,6 +272,7 @@ def get_request_handler(
request=request,
dependant=dependant,
body=body,
is_body_json=is_body_json,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,

15
tests/test_tutorial/test_body/test_tutorial001.py

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from dirty_equals import IsDict
from fastapi.testclient import TestClient
from starlette.routing import Request
from ...utils import needs_py310
@ -206,12 +207,10 @@ def test_post_broken_body(client: TestClient):
"detail": [
{
"type": "json_invalid",
"loc": ["body", 1],
"msg": "JSON decode error",
"input": {},
"ctx": {
"error": "Expecting property name enclosed in double quotes"
},
"loc": ["body"],
"msg": "Invalid JSON: key must be a string at line 1 column 2",
"input": "{some broken json}",
"ctx": {"error": "key must be a string at line 1 column 2"},
}
]
}
@ -221,7 +220,7 @@ def test_post_broken_body(client: TestClient):
"detail": [
{
"loc": ["body", 1],
"msg": "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
"msg": "Invalid JSON: Expecting property name enclosed in double quotes",
"type": "value_error.jsondecode",
"ctx": {
"msg": "Expecting property name enclosed in double quotes",
@ -383,7 +382,7 @@ def test_wrong_headers(client: TestClient):
def test_other_exceptions(client: TestClient):
with patch("json.loads", side_effect=Exception):
with patch.object(Request, "body", side_effect=Exception):
response = client.post("/items/", json={"test": "test2"})
assert response.status_code == 400, response.text

2
tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py

@ -20,7 +20,7 @@ def test_exception_handler_body_access():
{
"type": "list_type",
"loc": ["body"],
"msg": "Input should be a valid list",
"msg": "Input should be a valid array",
"input": {"numbers": [1, 2, 3]},
}
],

28
tests/test_tutorial/test_handling_errors/test_tutorial005.py

@ -19,7 +19,31 @@ def test_post_validation_error():
"input": "XL",
}
],
"body": {"title": "towel", "size": "XL"},
"body": '{"title":"towel","size":"XL"}',
}
) | IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": ["body", "size"],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "XL",
}
],
"body": '{"title": "towel", "size": "XL"}',
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": ["body", "size"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
],
"body": '{"title":"towel","size":"XL"}',
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
@ -31,7 +55,7 @@ def test_post_validation_error():
"type": "type_error.integer",
}
],
"body": {"title": "towel", "size": "XL"},
"body": '{"title": "towel", "size": "XL"}',
}
)

Loading…
Cancel
Save