diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 227ad837d..bd34e0154 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -1,3 +1,4 @@ +import json from collections import deque from copy import copy from dataclasses import dataclass, is_dataclass @@ -135,6 +136,22 @@ if PYDANTIC_V2: errors=exc.errors(include_url=False), loc_prefix=loc ) + def validate_json( + self, + value: Any, + *, + loc: Tuple[Union[int, str], ...] = (), + ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + try: + return ( + self._type_adapter.validate_json(value), + None, + ) + except ValidationError as exc: + return None, _regenerate_error_with_loc( + errors=exc.errors(include_url=False), loc_prefix=loc + ) + def serialize( self, value: Any, @@ -292,6 +309,18 @@ if PYDANTIC_V2: for name, field_info in model.model_fields.items() ] + def parse_json_field( + field: ModelField, + data: bytes, + values: Dict[str, Any], + loc: Tuple[str, ...] = (), + ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + value, errors = field.validate_json(data) + if isinstance(errors, list): + new_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=loc) + return None, new_errors + return value, errors + else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 @@ -530,6 +559,17 @@ else: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: return list(model.__fields__.values()) # type: ignore[attr-defined] + def parse_json_field( + field: ModelField, + data: bytes, + values: Dict[str, Any], + loc: Tuple[str, ...] = (), + ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + parsed_value, errors = parse_json(data, loc) + if errors: + return None, errors + return field.validate(parsed_value, values, loc=loc) + def _regenerate_error_with_loc( *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] @@ -662,3 +702,26 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool: @lru_cache def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]: return get_model_fields(model) + + +def parse_json( + data: bytes, + loc: Tuple[str, ...], +) -> Tuple[Any, List[Dict[str, Any]]]: + try: + return json.loads(data), [] + except json.JSONDecodeError as e: + return None, [ + { + "type": "value_error.jsondecode", + "loc": loc + (e.pos,), + "msg": f"Invalid JSON: {e.msg}", + "ctx": { + "msg": e.msg, + "doc": data, + "pos": e.pos, + "lineno": e.lineno, + "colno": e.colno, + }, + } + ] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..bc1f0edba 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -42,15 +42,14 @@ from fastapi._compat import ( is_uploadfile_or_nonable_uploadfile_annotation, is_uploadfile_sequence_annotation, lenient_issubclass, + parse_json, + parse_json_field, sequence_types, serialize_sequence_value, value_is_sequence, ) from fastapi.background import BackgroundTasks -from fastapi.concurrency import ( - asynccontextmanager, - contextmanager_in_threadpool, -) +from fastapi.concurrency import asynccontextmanager, contextmanager_in_threadpool from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.logger import logger from fastapi.security.base import SecurityBase @@ -573,7 +572,8 @@ async def solve_dependencies( *, request: Union[Request, WebSocket], dependant: Dependant, - body: Optional[Union[Dict[str, Any], FormData]] = None, + body: Optional[Union[bytes, Dict[str, Any], FormData]] = None, + is_body_json: bool = False, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, @@ -616,6 +616,7 @@ async def solve_dependencies( request=request, dependant=use_sub_dependant, body=body, + is_body_json=is_body_json, background_tasks=background_tasks, response=response, dependency_overrides_provider=dependency_overrides_provider, @@ -666,6 +667,7 @@ async def solve_dependencies( ) = await request_body_to_args( # body_params checked above body_fields=dependant.body_params, received_body=body, + is_body_json=is_body_json, embed_body_fields=embed_body_fields, ) values.update(body_values) @@ -695,6 +697,24 @@ async def solve_dependencies( ) +def _validate_json_body_as_model_field( + *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] +) -> Tuple[Any, List[Any]]: + if value is None: + if field.required: + return None, [get_missing_field_error(loc=loc)] + else: + return deepcopy(field.default), [] + v_, errors_ = parse_json_field(field, value, values=values, loc=loc) + if isinstance(errors_, ErrorWrapper): + return None, [errors_] + elif isinstance(errors_, list): + new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + return None, new_errors + else: + return v_, [] + + def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] ) -> Tuple[Any, List[Any]]: @@ -904,7 +924,8 @@ async def _extract_form_body( async def request_body_to_args( body_fields: List[ModelField], - received_body: Optional[Union[Dict[str, Any], FormData]], + received_body: Optional[Union[bytes, Dict[str, Any], FormData]], + is_body_json: bool, embed_body_fields: bool, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: values: Dict[str, Any] = {} @@ -924,16 +945,28 @@ async def request_body_to_args( if single_not_embedded_field: loc: Tuple[str, ...] = ("body",) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=body_to_process, values=values, loc=loc - ) + if is_body_json: + v_, errors_ = _validate_json_body_as_model_field( + field=first_field, value=body_to_process, values=values, loc=loc + ) + else: + v_, errors_ = _validate_value_with_model_field( + field=first_field, value=body_to_process, values=values, loc=loc + ) return {first_field.name: v_}, errors_ + + if is_body_json and isinstance(received_body, bytes): + body_to_process, errors = parse_json(received_body, loc=("body",)) + + if errors: + return values, errors + for field in body_fields: loc = ("body", field.alias) value: Optional[Any] = None if body_to_process is not None: try: - value = body_to_process.get(field.alias) + value = body_to_process.get(field.alias) # type: ignore[union-attr] # If the received body is a list, not a dict except AttributeError: errors.append(get_missing_field_error(loc)) diff --git a/fastapi/routing.py b/fastapi/routing.py index 54c75a027..d6c540535 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -2,7 +2,6 @@ import asyncio import dataclasses import email.message import inspect -import json from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum, IntEnum from typing import ( @@ -25,7 +24,6 @@ from typing import ( from fastapi import params from fastapi._compat import ( ModelField, - Undefined, _get_model_config, _model_dump, _normalize_errors, @@ -243,42 +241,23 @@ def get_request_handler( async with AsyncExitStack() as file_stack: try: body: Any = None + is_body_json = False if body_field: if is_body_form: body = await request.form() file_stack.push_async_callback(body.close) else: - body_bytes = await request.body() - if body_bytes: - json_body: Any = Undefined - content_type_value = request.headers.get("content-type") - if not content_type_value: - json_body = await request.json() - else: - message = email.message.Message() - message["content-type"] = content_type_value - if message.get_content_maintype() == "application": - subtype = message.get_content_subtype() - if subtype == "json" or subtype.endswith("+json"): - json_body = await request.json() - if json_body != Undefined: - body = json_body - else: - body = body_bytes - except json.JSONDecodeError as e: - validation_error = RequestValidationError( - [ - { - "type": "json_invalid", - "loc": ("body", e.pos), - "msg": "JSON decode error", - "input": {}, - "ctx": {"error": e.msg}, - } - ], - body=e.doc, - ) - raise validation_error from e + body = await request.body() or None + if body: + message = email.message.Message() + message["content-type"] = request.headers.get( + "content-type", "application/json" + ) + if message.get_content_maintype() == "application": + subtype = message.get_content_subtype() + is_body_json = subtype == "json" or subtype.endswith( + "+json" + ) except HTTPException: # If a middleware raises an HTTPException, it should be raised again raise @@ -293,6 +272,7 @@ def get_request_handler( request=request, dependant=dependant, body=body, + is_body_json=is_body_json, dependency_overrides_provider=dependency_overrides_provider, async_exit_stack=async_exit_stack, embed_body_fields=embed_body_fields, diff --git a/tests/test_tutorial/test_body/test_tutorial001.py b/tests/test_tutorial/test_body/test_tutorial001.py index f8b5aee8d..8211871c0 100644 --- a/tests/test_tutorial/test_body/test_tutorial001.py +++ b/tests/test_tutorial/test_body/test_tutorial001.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from dirty_equals import IsDict from fastapi.testclient import TestClient +from starlette.routing import Request from ...utils import needs_py310 @@ -206,12 +207,10 @@ def test_post_broken_body(client: TestClient): "detail": [ { "type": "json_invalid", - "loc": ["body", 1], - "msg": "JSON decode error", - "input": {}, - "ctx": { - "error": "Expecting property name enclosed in double quotes" - }, + "loc": ["body"], + "msg": "Invalid JSON: key must be a string at line 1 column 2", + "input": "{some broken json}", + "ctx": {"error": "key must be a string at line 1 column 2"}, } ] } @@ -221,7 +220,7 @@ def test_post_broken_body(client: TestClient): "detail": [ { "loc": ["body", 1], - "msg": "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)", + "msg": "Invalid JSON: Expecting property name enclosed in double quotes", "type": "value_error.jsondecode", "ctx": { "msg": "Expecting property name enclosed in double quotes", @@ -383,7 +382,7 @@ def test_wrong_headers(client: TestClient): def test_other_exceptions(client: TestClient): - with patch("json.loads", side_effect=Exception): + with patch.object(Request, "body", side_effect=Exception): response = client.post("/items/", json={"test": "test2"}) assert response.status_code == 400, response.text diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py index 647f1c5dd..b76cb43ec 100644 --- a/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py +++ b/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py @@ -20,7 +20,7 @@ def test_exception_handler_body_access(): { "type": "list_type", "loc": ["body"], - "msg": "Input should be a valid list", + "msg": "Input should be a valid array", "input": {"numbers": [1, 2, 3]}, } ], diff --git a/tests/test_tutorial/test_handling_errors/test_tutorial005.py b/tests/test_tutorial/test_handling_errors/test_tutorial005.py index 581b2e4c7..06b0eb560 100644 --- a/tests/test_tutorial/test_handling_errors/test_tutorial005.py +++ b/tests/test_tutorial/test_handling_errors/test_tutorial005.py @@ -19,7 +19,31 @@ def test_post_validation_error(): "input": "XL", } ], - "body": {"title": "towel", "size": "XL"}, + "body": '{"title":"towel","size":"XL"}', + } + ) | IsDict( + { + "detail": [ + { + "type": "int_parsing", + "loc": ["body", "size"], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "XL", + } + ], + "body": '{"title": "towel", "size": "XL"}', + } + ) | IsDict( + # TODO: remove when deprecating Pydantic v1 + { + "detail": [ + { + "loc": ["body", "size"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ], + "body": '{"title":"towel","size":"XL"}', } ) | IsDict( # TODO: remove when deprecating Pydantic v1 @@ -31,7 +55,7 @@ def test_post_validation_error(): "type": "type_error.integer", } ], - "body": {"title": "towel", "size": "XL"}, + "body": '{"title": "towel", "size": "XL"}', } )