Browse Source

Use TypeAdapter.validate_json for Pydantic v2

pull/13951/head
Martynov Maxim 4 days ago
parent
commit
3160247e21
No known key found for this signature in database GPG Key ID: 9C23E39F5BBC88CC
  1. 63
      fastapi/_compat.py
  2. 47
      fastapi/dependencies/utils.py
  3. 38
      fastapi/routing.py
  4. 15
      tests/test_tutorial/test_body/test_tutorial001.py
  5. 2
      tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py
  6. 28
      tests/test_tutorial/test_handling_errors/test_tutorial005.py

63
fastapi/_compat.py

@ -1,3 +1,4 @@
import json
from collections import deque from collections import deque
from copy import copy from copy import copy
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
@ -135,6 +136,22 @@ if PYDANTIC_V2:
errors=exc.errors(include_url=False), loc_prefix=loc errors=exc.errors(include_url=False), loc_prefix=loc
) )
def validate_json(
self,
value: Any,
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_json(value),
None,
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
def serialize( def serialize(
self, self,
value: Any, value: Any,
@ -292,6 +309,18 @@ if PYDANTIC_V2:
for name, field_info in model.model_fields.items() for name, field_info in model.model_fields.items()
] ]
def parse_json_field(
field: ModelField,
data: bytes,
values: Dict[str, Any],
loc: Tuple[str, ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
value, errors = field.validate_json(data)
if isinstance(errors, list):
new_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=loc)
return None, new_errors
return value, errors
else: else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401 from pydantic import AnyUrl as Url # noqa: F401
@ -530,6 +559,17 @@ else:
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined] return list(model.__fields__.values()) # type: ignore[attr-defined]
def parse_json_field(
field: ModelField,
data: bytes,
values: Dict[str, Any],
loc: Tuple[str, ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
parsed_value, errors = parse_json(data, loc)
if errors:
return None, errors
return field.validate(parsed_value, values, loc=loc)
def _regenerate_error_with_loc( def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
@ -662,3 +702,26 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
@lru_cache @lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]: def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return get_model_fields(model) return get_model_fields(model)
def parse_json(
data: bytes,
loc: Tuple[str, ...],
) -> Tuple[Any, List[Dict[str, Any]]]:
try:
return json.loads(data), []
except json.JSONDecodeError as e:
return None, [
{
"type": "value_error.jsondecode",
"loc": loc + (e.pos,),
"msg": f"Invalid JSON: {e.msg}",
"ctx": {
"msg": e.msg,
"doc": data,
"pos": e.pos,
"lineno": e.lineno,
"colno": e.colno,
},
}
]

47
fastapi/dependencies/utils.py

@ -42,15 +42,14 @@ from fastapi._compat import (
is_uploadfile_or_nonable_uploadfile_annotation, is_uploadfile_or_nonable_uploadfile_annotation,
is_uploadfile_sequence_annotation, is_uploadfile_sequence_annotation,
lenient_issubclass, lenient_issubclass,
parse_json,
parse_json_field,
sequence_types, sequence_types,
serialize_sequence_value, serialize_sequence_value,
value_is_sequence, value_is_sequence,
) )
from fastapi.background import BackgroundTasks from fastapi.background import BackgroundTasks
from fastapi.concurrency import ( from fastapi.concurrency import asynccontextmanager, contextmanager_in_threadpool
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
@ -573,7 +572,8 @@ async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
dependant: Dependant, dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None, body: Optional[Union[bytes, Dict[str, Any], FormData]] = None,
is_body_json: bool = False,
background_tasks: Optional[StarletteBackgroundTasks] = None, background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
@ -616,6 +616,7 @@ async def solve_dependencies(
request=request, request=request,
dependant=use_sub_dependant, dependant=use_sub_dependant,
body=body, body=body,
is_body_json=is_body_json,
background_tasks=background_tasks, background_tasks=background_tasks,
response=response, response=response,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
@ -666,6 +667,7 @@ async def solve_dependencies(
) = await request_body_to_args( # body_params checked above ) = await request_body_to_args( # body_params checked above
body_fields=dependant.body_params, body_fields=dependant.body_params,
received_body=body, received_body=body,
is_body_json=is_body_json,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
values.update(body_values) values.update(body_values)
@ -695,6 +697,24 @@ async def solve_dependencies(
) )
def _validate_json_body_as_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
if value is None:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
v_, errors_ = parse_json_field(field, value, values=values, loc=loc)
if isinstance(errors_, ErrorWrapper):
return None, [errors_]
elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
return None, new_errors
else:
return v_, []
def _validate_value_with_model_field( def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]: ) -> Tuple[Any, List[Any]]:
@ -904,7 +924,8 @@ async def _extract_form_body(
async def request_body_to_args( async def request_body_to_args(
body_fields: List[ModelField], body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]], received_body: Optional[Union[bytes, Dict[str, Any], FormData]],
is_body_json: bool,
embed_body_fields: bool, embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
@ -924,16 +945,28 @@ async def request_body_to_args(
if single_not_embedded_field: if single_not_embedded_field:
loc: Tuple[str, ...] = ("body",) loc: Tuple[str, ...] = ("body",)
if is_body_json:
v_, errors_ = _validate_json_body_as_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
)
else:
v_, errors_ = _validate_value_with_model_field( v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc field=first_field, value=body_to_process, values=values, loc=loc
) )
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
if is_body_json and isinstance(received_body, bytes):
body_to_process, errors = parse_json(received_body, loc=("body",))
if errors:
return values, errors
for field in body_fields: for field in body_fields:
loc = ("body", field.alias) loc = ("body", field.alias)
value: Optional[Any] = None value: Optional[Any] = None
if body_to_process is not None: if body_to_process is not None:
try: try:
value = body_to_process.get(field.alias) value = body_to_process.get(field.alias) # type: ignore[union-attr]
# If the received body is a list, not a dict # If the received body is a list, not a dict
except AttributeError: except AttributeError:
errors.append(get_missing_field_error(loc)) errors.append(get_missing_field_error(loc))

38
fastapi/routing.py

@ -2,7 +2,6 @@ import asyncio
import dataclasses import dataclasses
import email.message import email.message
import inspect import inspect
import json
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum from enum import Enum, IntEnum
from typing import ( from typing import (
@ -25,7 +24,6 @@ from typing import (
from fastapi import params from fastapi import params
from fastapi._compat import ( from fastapi._compat import (
ModelField, ModelField,
Undefined,
_get_model_config, _get_model_config,
_model_dump, _model_dump,
_normalize_errors, _normalize_errors,
@ -243,42 +241,23 @@ def get_request_handler(
async with AsyncExitStack() as file_stack: async with AsyncExitStack() as file_stack:
try: try:
body: Any = None body: Any = None
is_body_json = False
if body_field: if body_field:
if is_body_form: if is_body_form:
body = await request.form() body = await request.form()
file_stack.push_async_callback(body.close) file_stack.push_async_callback(body.close)
else: else:
body_bytes = await request.body() body = await request.body() or None
if body_bytes: if body:
json_body: Any = Undefined
content_type_value = request.headers.get("content-type")
if not content_type_value:
json_body = await request.json()
else:
message = email.message.Message() message = email.message.Message()
message["content-type"] = content_type_value message["content-type"] = request.headers.get(
"content-type", "application/json"
)
if message.get_content_maintype() == "application": if message.get_content_maintype() == "application":
subtype = message.get_content_subtype() subtype = message.get_content_subtype()
if subtype == "json" or subtype.endswith("+json"): is_body_json = subtype == "json" or subtype.endswith(
json_body = await request.json() "+json"
if json_body != Undefined:
body = json_body
else:
body = body_bytes
except json.JSONDecodeError as e:
validation_error = RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
],
body=e.doc,
) )
raise validation_error from e
except HTTPException: except HTTPException:
# If a middleware raises an HTTPException, it should be raised again # If a middleware raises an HTTPException, it should be raised again
raise raise
@ -293,6 +272,7 @@ def get_request_handler(
request=request, request=request,
dependant=dependant, dependant=dependant,
body=body, body=body,
is_body_json=is_body_json,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack, async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,

15
tests/test_tutorial/test_body/test_tutorial001.py

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from starlette.routing import Request
from ...utils import needs_py310 from ...utils import needs_py310
@ -206,12 +207,10 @@ def test_post_broken_body(client: TestClient):
"detail": [ "detail": [
{ {
"type": "json_invalid", "type": "json_invalid",
"loc": ["body", 1], "loc": ["body"],
"msg": "JSON decode error", "msg": "Invalid JSON: key must be a string at line 1 column 2",
"input": {}, "input": "{some broken json}",
"ctx": { "ctx": {"error": "key must be a string at line 1 column 2"},
"error": "Expecting property name enclosed in double quotes"
},
} }
] ]
} }
@ -221,7 +220,7 @@ def test_post_broken_body(client: TestClient):
"detail": [ "detail": [
{ {
"loc": ["body", 1], "loc": ["body", 1],
"msg": "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)", "msg": "Invalid JSON: Expecting property name enclosed in double quotes",
"type": "value_error.jsondecode", "type": "value_error.jsondecode",
"ctx": { "ctx": {
"msg": "Expecting property name enclosed in double quotes", "msg": "Expecting property name enclosed in double quotes",
@ -383,7 +382,7 @@ def test_wrong_headers(client: TestClient):
def test_other_exceptions(client: TestClient): def test_other_exceptions(client: TestClient):
with patch("json.loads", side_effect=Exception): with patch.object(Request, "body", side_effect=Exception):
response = client.post("/items/", json={"test": "test2"}) response = client.post("/items/", json={"test": "test2"})
assert response.status_code == 400, response.text assert response.status_code == 400, response.text

2
tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py

@ -20,7 +20,7 @@ def test_exception_handler_body_access():
{ {
"type": "list_type", "type": "list_type",
"loc": ["body"], "loc": ["body"],
"msg": "Input should be a valid list", "msg": "Input should be a valid array",
"input": {"numbers": [1, 2, 3]}, "input": {"numbers": [1, 2, 3]},
} }
], ],

28
tests/test_tutorial/test_handling_errors/test_tutorial005.py

@ -19,7 +19,31 @@ def test_post_validation_error():
"input": "XL", "input": "XL",
} }
], ],
"body": {"title": "towel", "size": "XL"}, "body": '{"title":"towel","size":"XL"}',
}
) | IsDict(
{
"detail": [
{
"type": "int_parsing",
"loc": ["body", "size"],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "XL",
}
],
"body": '{"title": "towel", "size": "XL"}',
}
) | IsDict(
# TODO: remove when deprecating Pydantic v1
{
"detail": [
{
"loc": ["body", "size"],
"msg": "value is not a valid integer",
"type": "type_error.integer",
}
],
"body": '{"title":"towel","size":"XL"}',
} }
) | IsDict( ) | IsDict(
# TODO: remove when deprecating Pydantic v1 # TODO: remove when deprecating Pydantic v1
@ -31,7 +55,7 @@ def test_post_validation_error():
"type": "type_error.integer", "type": "type_error.integer",
} }
], ],
"body": {"title": "towel", "size": "XL"}, "body": '{"title": "towel", "size": "XL"}',
} }
) )

Loading…
Cancel
Save