Browse Source

Optimize JSON parsing by using Pydantic V2 validate_json directly from request body bytes.

- Add FastAPIOptimizedJsonBytes helper to mark bytes for optimized parsing.
- Implement validate_json in ModelField to leverage pydantic-core's native JSON parsing.
- Update get_request_handler to pass raw body bytes wrapped in FastAPIOptimizedJsonBytes when applicable.
- Update _validate_value_with_model_field to use validate_json when receiving optimized bytes, avoiding redundant Python dict conversion.

Co-authored-by: Junie <[email protected]>
pull/15584/head
valbort 3 weeks ago
parent
commit
1bffdc5f36
  1. 73
      fastapi/_compat/v2.py
  2. 45
      fastapi/dependencies/utils.py
  3. 13
      fastapi/routing.py
  4. 4
      fastapi/utils.py

73
fastapi/_compat/v2.py

@ -187,6 +187,31 @@ class ModelField:
errors=exc.errors(include_url=False), loc_prefix=loc
)
def validate_json(
self,
value: str | bytes,
values: dict[str, Any] = {}, # noqa: B006
*,
loc: tuple[int | str, ...] = (),
) -> tuple[Any, list[dict[str, Any]]]:
try:
return (
self._type_adapter.validate_json(value),
[],
)
except ValidationError as exc:
errors = exc.errors(include_url=False)
decoded_value: Any = Undefined
try:
import json
decoded_value = json.loads(value)
except Exception:
pass
return None, _regenerate_error_with_loc(
errors=errors, loc_prefix=loc, decoded_value=decoded_value
)
def serialize(
self,
value: Any,
@ -484,10 +509,50 @@ def get_flat_models_from_fields(
def _regenerate_error_with_loc(
    *,
    errors: Sequence[Any],
    loc_prefix: tuple[str | int, ...],
    decoded_value: Any = Undefined,
) -> list[dict[str, Any]]:
    """Return copies of ``errors`` with ``loc_prefix`` prepended to each ``loc``.

    Args:
        errors: Error dicts; each may carry a ``loc`` tuple relative to the
            validated value.
        loc_prefix: Location prefix to prepend, e.g. ``("body",)``.
        decoded_value: The decoded Python value that was validated. When it is
            provided (not ``Undefined``), each error's ``"input"`` is replaced
            by the part of ``decoded_value`` that the error's relative ``loc``
            points at.

    Returns:
        A new list of error dicts; the input dicts are not modified.
    """
    updated_loc_errors: list[Any] = []
    for err in errors:
        rel_loc = err.get("loc", ())
        new_err = {**err, "loc": loc_prefix + rel_loc}
        if decoded_value is not Undefined:
            new_err["input"] = _resolve_error_input(decoded_value, rel_loc)
            # NOTE(review): presumably normalises the wording to the one used
            # when validating Python objects rather than JSON - confirm.
            if new_err.get("msg") == "Input should be a valid array":
                new_err["msg"] = "Input should be a valid list"
        updated_loc_errors.append(new_err)
    return updated_loc_errors


def _resolve_error_input(decoded_value: Any, rel_loc: tuple[str | int, ...]) -> Any:
    """Find the value inside ``decoded_value`` that ``rel_loc`` points at."""
    if rel_loc and rel_loc[-1] == "[key]":
        # A trailing "[key]" marks an error on a dict key: the input is the key
        # itself, i.e. the path item just before the marker. Guard the lookup
        # so a location consisting only of the marker cannot raise IndexError.
        return rel_loc[-2] if len(rel_loc) >= 2 else decoded_value
    curr_input = decoded_value
    for path_item in rel_loc:
        # Stop at a "[key]" marker or as soon as the path can no longer be
        # followed; the deepest value reached so far is reported.
        if path_item == "[key]" or not isinstance(curr_input, (dict, list)):
            break
        try:
            curr_input = curr_input[path_item]  # type: ignore[index]
        except (KeyError, IndexError, TypeError):
            break
    return curr_input

45
fastapi/dependencies/utils.py

@ -59,7 +59,11 @@ from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
from fastapi.utils import (
FastAPIOptimizedJsonBytes,
create_model_field,
get_path_param_names,
)
from pydantic import BaseModel, Json
from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
@ -743,11 +747,32 @@ def _validate_value_with_model_field(
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
if (
isinstance(value, (str, bytes))
and not field_annotation_is_scalar(field.field_info.annotation)
and not is_scalar_field(field)
and not _is_json_field(field)
):
if isinstance(value, FastAPIOptimizedJsonBytes):
return field.validate_json(value, values, loc=loc)
return field.validate(value, values, loc=loc)
# If it's a scalar and we have bytes, we MUST decode it first because Pydantic's
# validate_python doesn't handle JSON-encoded scalar bytes (like b'"-1"')
if isinstance(value, bytes) and field_annotation_is_scalar(field.field_info.annotation):
try:
import json
value = json.loads(value)
except json.JSONDecodeError:
pass
return field.validate(value, values, loc=loc)
def _is_json_field(field: ModelField) -> bool:
    """Return whether the field's metadata marks it as a ``Json`` field.

    A metadata item matches when it is an instance whose exact type is
    ``Json`` or when it is the ``Json`` class itself.
    """
    return any(
        type(item) is Json or item is Json for item in field.field_info.metadata
    )
def _get_multidict_value(
@ -978,6 +1003,22 @@ async def request_body_to_args(
field=first_field, value=body_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_
if isinstance(received_body, bytes):
try:
import json
body_to_process = json.loads(received_body)
except json.JSONDecodeError as e:
return values, [
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
]
for field in body_fields:
loc = ("body", get_validation_alias(field))
value: Any | None = None

13
fastapi/routing.py

@ -67,6 +67,7 @@ from fastapi.sse import (
)
from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import (
FastAPIOptimizedJsonBytes,
create_model_field,
generate_unique_id,
get_value_or_default,
@ -421,9 +422,11 @@ def get_request_handler(
if subtype == "json" or subtype.endswith("+json"):
json_body = await request.json()
if json_body != Undefined:
body = json_body
body = FastAPIOptimizedJsonBytes(body_bytes)
else:
body = body_bytes
except RequestValidationError:
raise
except json.JSONDecodeError as e:
validation_error = RequestValidationError(
[
@ -717,6 +720,14 @@ def get_request_handler(
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
# If the body is still bytes (because of the optimization), decode it
# back to a Python object for the exception handler to be consistent
# with previous versions of FastAPI.
if isinstance(body, bytes):
try:
body = json.loads(body)
except Exception:
pass
validation_error = RequestValidationError(
errors, body=body, endpoint_ctx=endpoint_ctx
)

4
fastapi/utils.py

@ -134,3 +134,7 @@ def get_value_or_default(
if not isinstance(item, DefaultPlaceholder):
return item
return first_item
class FastAPIOptimizedJsonBytes(bytes):
    """Marker subclass of ``bytes`` for a raw, still-encoded JSON payload.

    It adds no behaviour of its own. Code that receives a value can use
    ``isinstance(value, FastAPIOptimizedJsonBytes)`` to tell such a payload
    apart from ordinary ``bytes`` and validate it straight from the JSON text.
    """

Loading…
Cancel
Save