diff --git a/docs/en/docs/tutorial/handling-errors.md b/docs/en/docs/tutorial/handling-errors.md index 4d969747f..1dd6a59e5 100644 --- a/docs/en/docs/tutorial/handling-errors.md +++ b/docs/en/docs/tutorial/handling-errors.md @@ -162,7 +162,7 @@ These are technical details that you might skip if it's not important for you no /// -`RequestValidationError` is a sub-class of Pydantic's `ValidationError`. +`RequestValidationError` is morally a sub-class of Pydantic's `ValidationError`. **FastAPI** uses it so that, if you use a Pydantic model in `response_model`, and your data has an error, you will see the error in your log. diff --git a/fastapi/_compat.py b/fastapi/_compat.py index c07e4a3b0..24433381a 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, is_dataclass from enum import Enum from functools import lru_cache from typing import ( + TYPE_CHECKING, Any, Callable, Deque, @@ -23,7 +24,14 @@ from fastapi.types import IncEx, ModelNameMap, UnionType from pydantic import BaseModel, create_model from pydantic.version import VERSION as PYDANTIC_VERSION from starlette.datastructures import UploadFile -from typing_extensions import Annotated, Literal, get_args, get_origin +from typing_extensions import ( + Annotated, + Literal, + TypeAlias, + assert_never, + get_args, + get_origin, +) PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 @@ -60,6 +68,7 @@ if PYDANTIC_V2: from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue from pydantic_core import CoreSchema as CoreSchema + from pydantic_core import ErrorDetails as ErrorDetails from pydantic_core import PydanticUndefined, PydanticUndefinedType from pydantic_core import Url as Url @@ -69,7 +78,7 @@ if PYDANTIC_V2: ) except ImportError: # pragma: no cover from pydantic_core.core_schema import ( - general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 + general_plain_validator_function as with_info_plain_validator_function, ) RequiredParam = PydanticUndefined @@ -84,6 +93,9 @@ if PYDANTIC_V2: class ErrorWrapper(Exception): pass + # See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/error_wrappers.py#L45-L47. + ErrorList: TypeAlias = Union[Sequence["ErrorList"], ErrorWrapper] + @dataclass class ModelField: field_info: FieldInfo @@ -117,22 +129,25 @@ if PYDANTIC_V2: return Undefined return self.field_info.get_default(call_default_factory=True) + # See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/fields.py#L850-L852 for the signature. def validate( self, value: Any, values: Dict[str, Any] = {}, # noqa: B006 *, loc: Tuple[Union[int, str], ...] = (), - ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + ) -> Tuple[Any, Union[ErrorList, Sequence[ErrorDetails], None]]: try: return ( self._type_adapter.validate_python(value, from_attributes=True), None, ) except ValidationError as exc: - return None, _regenerate_error_with_loc( - errors=exc.errors(include_url=False), loc_prefix=loc - ) + errors: List[ErrorDetails] = [ + {**err, "loc": loc + err["loc"]} + for err in exc.errors(include_url=False) + ] + return None, errors def serialize( self, @@ -169,7 +184,13 @@ if PYDANTIC_V2: ) -> Any: return annotation - def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: + def _normalize_errors( + errors: Union[ErrorList, Sequence[ErrorDetails]], + ) -> List[ErrorDetails]: + assert isinstance(errors, Sequence), type(errors) + for error in errors: + assert not isinstance(error, ErrorWrapper) + assert not isinstance(error, Sequence) return errors # type: ignore[return-value] def _model_rebuild(model: Type[BaseModel]) -> None: @@ -267,12 +288,12 @@ if PYDANTIC_V2: assert issubclass(origin_type, sequence_types) # type: ignore[arg-type] return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] - def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: - error = ValidationError.from_exception_data( + def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails: + [error] = ValidationError.from_exception_data( "Field required", [{"type": "missing", "loc": loc, "input": {}}] - ).errors(include_url=False)[0] + ).errors(include_url=False, include_input=False) error["input"] = None - return error # type: ignore[return-value] + return error def create_body_model( *, fields: Sequence[ModelField], model_name: str @@ -291,14 +312,22 @@ else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 from pydantic import ( # type: ignore[assignment] - BaseConfig as BaseConfig, # noqa: F401 + BaseConfig as BaseConfig, ) - from pydantic import ValidationError as ValidationError # noqa: F401 + from pydantic import ValidationError as ValidationError from pydantic.class_validators import ( # type: ignore[no-redef] - Validator as Validator, # noqa: F401 + Validator as Validator, + ) + + if TYPE_CHECKING: # pragma: nocover + from pydantic.error_wrappers import ( # type: ignore[no-redef] + ErrorDict as ErrorDetails, + ) + from pydantic.error_wrappers import ( # type: ignore[no-redef] + ErrorList as ErrorList, ) from pydantic.error_wrappers import ( # type: ignore[no-redef] - ErrorWrapper as ErrorWrapper, # noqa: F401 + ErrorWrapper as ErrorWrapper, ) from pydantic.errors import MissingError from pydantic.fields import ( # type: ignore[attr-defined] @@ -312,7 +341,7 @@ else: ) from pydantic.fields import FieldInfo as FieldInfo from pydantic.fields import ( # type: ignore[no-redef,attr-defined] - ModelField as ModelField, # noqa: F401 + ModelField as ModelField, ) # Keeping old "Required" functionality from Pydantic V1, without @@ -322,7 +351,7 @@ else: Undefined as Undefined, ) from pydantic.fields import ( # type: ignore[no-redef, attr-defined] - UndefinedType as UndefinedType, # noqa: F401 + UndefinedType as UndefinedType, ) from pydantic.schema import ( field_schema, @@ -330,14 +359,14 @@ else: get_model_name_map, model_process_schema, ) - from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401 + from pydantic.schema import ( # type: ignore[no-redef] get_annotation_from_field_info as get_annotation_from_field_info, ) from pydantic.typing import ( # type: ignore[no-redef] - evaluate_forwardref as evaluate_forwardref, # noqa: F401 + evaluate_forwardref as evaluate_forwardref, ) from pydantic.utils import ( # type: ignore[no-redef] - lenient_issubclass as lenient_issubclass, # noqa: F401 + lenient_issubclass as lenient_issubclass, ) GetJsonSchemaHandler = Any # type: ignore[assignment,misc] @@ -427,18 +456,23 @@ else: return True return False - def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: - use_errors: List[Any] = [] - for error in errors: - if isinstance(error, ErrorWrapper): - new_errors = ValidationError( # type: ignore[call-arg] - errors=[error], model=RequestErrorModel + def _normalize_errors( + errors: Union[ErrorList, Sequence["ErrorDetails"]], + ) -> List["ErrorDetails"]: + use_errors: List[ErrorDetails] = [] + if isinstance(errors, ErrorWrapper): + use_errors.extend( + ValidationError( # type: ignore[call-arg] + errors=[errors], model=RequestErrorModel ).errors() - use_errors.extend(new_errors) - elif isinstance(error, list): + ) + elif isinstance(errors, Sequence): + for error in errors: + assert not isinstance(error, dict) use_errors.extend(_normalize_errors(error)) - else: - use_errors.append(error) + return use_errors + else: + assert_never(errors) # pragma: no cover return use_errors def _model_rebuild(model: Type[BaseModel]) -> None: @@ -509,10 +543,10 @@ else: def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined] - def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + def get_missing_field_error(loc: Tuple[str, ...]) -> "ErrorDetails": missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg] - new_error = ValidationError([missing_field_error], RequestErrorModel) - return new_error.errors()[0] # type: ignore[return-value] + [new_error] = ValidationError([missing_field_error], RequestErrorModel).errors() + return new_error def create_body_model( *, fields: Sequence[ModelField], model_name: str @@ -526,17 +560,6 @@ else: return list(model.__fields__.values()) # type: ignore[attr-defined] -def _regenerate_error_with_loc( - *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] -) -> List[Dict[str, Any]]: - updated_loc_errors: List[Any] = [ - {**err, "loc": loc_prefix + err.get("loc", ())} - for err in _normalize_errors(errors) - ] - - return updated_loc_errors - - def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: if lenient_issubclass(annotation, (str, bytes)): return False diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 84dfa4d03..d5247c3e7 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -3,6 +3,7 @@ from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, Callable, Coroutine, @@ -22,11 +23,10 @@ import anyio from fastapi import params from fastapi._compat import ( PYDANTIC_V2, - ErrorWrapper, ModelField, RequiredParam, Undefined, - _regenerate_error_with_loc, + _normalize_errors, copy_field_info, create_body_model, evaluate_forwardref, @@ -46,6 +46,9 @@ from fastapi._compat import ( serialize_sequence_value, value_is_sequence, ) + +if TYPE_CHECKING: # pragma: nocover + from fastapi._compat import ErrorDetails from fastapi.background import BackgroundTasks from fastapi.concurrency import ( asynccontextmanager, @@ -563,7 +566,7 @@ async def solve_generator( @dataclass class SolvedDependency: values: Dict[str, Any] - errors: List[Any] + errors: List["ErrorDetails"] background_tasks: Optional[StarletteBackgroundTasks] response: Response dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] @@ -582,7 +585,7 @@ async def solve_dependencies( embed_body_fields: bool, ) -> SolvedDependency: values: Dict[str, Any] = {} - errors: List[Any] = [] + errors: List[ErrorDetails] = [] if response is None: response = Response() del response.headers["content-length"] @@ -658,7 +661,8 @@ async def solve_dependencies( values.update(query_values) values.update(header_values) values.update(cookie_values) - errors += path_errors + query_errors + header_errors + cookie_errors + for errors_ in (path_errors, query_errors, header_errors, cookie_errors): + errors.extend(errors_) if dependant.body_params: ( body_values, @@ -697,17 +701,15 @@ async def solve_dependencies( def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] -) -> Tuple[Any, List[Any]]: +) -> Tuple[Any, List["ErrorDetails"]]: if value is None: if field.required: return None, [get_missing_field_error(loc=loc)] else: return deepcopy(field.default), [] v_, errors_ = field.validate(value, values, loc=loc) - if isinstance(errors_, ErrorWrapper): - return None, [errors_] - elif isinstance(errors_, list): - new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + if errors_ is not None: + new_errors = _normalize_errors(errors_) return None, new_errors else: return v_, [] @@ -740,9 +742,9 @@ def _get_multidict_value( def request_params_to_args( fields: Sequence[ModelField], received_params: Union[Mapping[str, Any], QueryParams, Headers], -) -> Tuple[Dict[str, Any], List[Any]]: +) -> Tuple[Dict[str, Any], List["ErrorDetails"]]: values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] + errors: List[ErrorDetails] = [] if not fields: return values, errors @@ -885,9 +887,9 @@ async def request_body_to_args( body_fields: List[ModelField], received_body: Optional[Union[Dict[str, Any], FormData]], embed_body_fields: bool, -) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: +) -> Tuple[Dict[str, Any], List["ErrorDetails"]]: values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] + errors: List[ErrorDetails] = [] assert body_fields, "request_body_to_args() should be called with fields" single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields first_field = body_fields[0] diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 44d4ada86..ca41bdac1 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -1,5 +1,7 @@ -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union +if TYPE_CHECKING: # pragma: nocover + from fastapi._compat import ErrorDetails from pydantic import BaseModel, create_model from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import WebSocketException as StarletteWebSocketException @@ -147,15 +149,15 @@ class FastAPIError(RuntimeError): class ValidationException(Exception): - def __init__(self, errors: Sequence[Any]) -> None: + def __init__(self, errors: Sequence["ErrorDetails"]) -> None: self._errors = errors - def errors(self) -> Sequence[Any]: + def errors(self) -> Sequence["ErrorDetails"]: return self._errors class RequestValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None: super().__init__(errors) self.body = body @@ -165,7 +167,7 @@ class WebSocketRequestValidationError(ValidationException): class ResponseValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None: super().__init__(errors) self.body = body diff --git a/fastapi/routing.py b/fastapi/routing.py index 457481e32..68d9d4988 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -153,7 +153,6 @@ async def serialize_response( is_coroutine: bool = True, ) -> Any: if field: - errors = [] if not hasattr(field, "serialize"): # pydantic v1 response_content = _prepare_response_content( @@ -163,15 +162,11 @@ async def serialize_response( exclude_none=exclude_none, ) if is_coroutine: - value, errors_ = field.validate(response_content, {}, loc=("response",)) + value, errors = field.validate(response_content, {}, loc=("response",)) else: - value, errors_ = await run_in_threadpool( + value, errors = await run_in_threadpool( field.validate, response_content, {}, loc=("response",) ) - if isinstance(errors_, list): - errors.extend(errors_) - elif errors_: - errors.append(errors_) if errors: raise ResponseValidationError( errors=_normalize_errors(errors), body=response_content @@ -286,7 +281,6 @@ def get_request_handler( status_code=400, detail="There was an error parsing the body" ) raise http_error from e - errors: List[Any] = [] async with AsyncExitStack() as async_exit_stack: solved_result = await solve_dependencies( request=request, @@ -340,9 +334,7 @@ def get_request_handler( response.body = b"" response.headers.raw.extend(solved_result.response.headers.raw) if errors: - validation_error = RequestValidationError( - _normalize_errors(errors), body=body - ) + validation_error = RequestValidationError(errors, body=body) raise validation_error if response is None: raise FastAPIError( @@ -376,9 +368,7 @@ def get_websocket_app( embed_body_fields=embed_body_fields, ) if solved_result.errors: - raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) - ) + raise WebSocketRequestValidationError(solved_result.errors) assert dependant.call is not None, "dependant.call must be a function" await dependant.call(**solved_result.values)