Browse Source

ValidationException.errors() are ErrorDetails

Update the documentation to explain that `RequestValidationError` isn't
literally a subclass since Pydantic V2.
pull/11542/head
Tamir Duberstein 11 months ago
parent
commit
4ecfb90850
Failed to extract signature
  1. 2
      docs/en/docs/tutorial/handling-errors.md
  2. 91
      fastapi/_compat.py
  3. 30
      fastapi/dependencies/utils.py
  4. 12
      fastapi/exceptions.py
  5. 18
      fastapi/routing.py

2
docs/en/docs/tutorial/handling-errors.md

@ -162,7 +162,7 @@ These are technical details that you might skip if it's not important for you no
///
`RequestValidationError` is a sub-class of Pydantic's <a href="https://docs.pydantic.dev/latest/concepts/models/#error-handling" class="external-link" target="_blank">`ValidationError`</a>.
`RequestValidationError` is morally a sub-class of Pydantic's <a href="https://docs.pydantic.dev/latest/concepts/models/#error-handling" class="external-link" target="_blank">`ValidationError`</a>.
**FastAPI** uses it so that, if you use a Pydantic model in `response_model`, and your data has an error, you will see the error in your log.

91
fastapi/_compat.py

@ -4,6 +4,7 @@ from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
@ -23,7 +24,14 @@ from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin
from typing_extensions import (
Annotated,
Literal,
TypeAlias,
assert_never,
get_args,
get_origin,
)
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
@ -60,6 +68,7 @@ if PYDANTIC_V2:
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import ErrorDetails as ErrorDetails
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
@ -84,6 +93,9 @@ if PYDANTIC_V2:
class ErrorWrapper(Exception):
pass
# See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/error_wrappers.py#L45-L47.
ErrorList: TypeAlias = Union[Sequence["ErrorList"], ErrorWrapper]
@dataclass
class ModelField:
field_info: FieldInfo
@ -117,22 +129,25 @@ if PYDANTIC_V2:
return Undefined
return self.field_info.get_default(call_default_factory=True)
# See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/fields.py#L850-L852 for the signature.
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
) -> Tuple[Any, Union[ErrorList, Sequence[ErrorDetails], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
errors: List[ErrorDetails] = [
{**err, "loc": loc + err["loc"]}
for err in exc.errors(include_url=False)
]
return None, errors
def serialize(
self,
@ -169,7 +184,13 @@ if PYDANTIC_V2:
) -> Any:
return annotation
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
def _normalize_errors(
errors: Union[ErrorList, Sequence[ErrorDetails]],
) -> List[ErrorDetails]:
assert isinstance(errors, Sequence), type(errors)
for error in errors:
assert not isinstance(error, ErrorWrapper)
assert not isinstance(error, Sequence)
return errors # type: ignore[return-value]
def _model_rebuild(model: Type[BaseModel]) -> None:
@ -267,12 +288,12 @@ if PYDANTIC_V2:
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
error = ValidationError.from_exception_data(
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails:
[error] = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
).errors(include_url=False, include_input=False)
error["input"] = None
return error # type: ignore[return-value]
return error
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
@ -297,6 +318,14 @@ else:
from pydantic.class_validators import ( # type: ignore[no-redef]
Validator as Validator,
)
if TYPE_CHECKING: # pragma: nocover
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorDict as ErrorDetails,
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorList as ErrorList,
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper,
)
@ -427,18 +456,23 @@ else:
return True
return False
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
use_errors: List[Any] = []
for error in errors:
if isinstance(error, ErrorWrapper):
new_errors = ValidationError( # type: ignore[call-arg]
errors=[error], model=RequestErrorModel
def _normalize_errors(
errors: Union[ErrorList, Sequence["ErrorDetails"]],
) -> List["ErrorDetails"]:
use_errors: List[ErrorDetails] = []
if isinstance(errors, ErrorWrapper):
use_errors.extend(
ValidationError( # type: ignore[call-arg]
errors=[errors], model=RequestErrorModel
).errors()
use_errors.extend(new_errors)
elif isinstance(error, list):
)
elif isinstance(errors, Sequence):
for error in errors:
assert not isinstance(error, dict)
use_errors.extend(_normalize_errors(error))
else:
use_errors.append(error)
return use_errors
else:
assert_never(errors) # pragma: no cover
return use_errors
def _model_rebuild(model: Type[BaseModel]) -> None:
@ -509,10 +543,10 @@ else:
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: Tuple[str, ...]) -> "ErrorDetails":
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
[new_error] = ValidationError([missing_field_error], RequestErrorModel).errors()
return new_error
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
@ -526,17 +560,6 @@ else:
return list(model.__fields__.values()) # type: ignore[attr-defined]
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False

30
fastapi/dependencies/utils.py

@ -3,6 +3,7 @@ from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
@ -22,11 +23,10 @@ import anyio
from fastapi import params
from fastapi._compat import (
PYDANTIC_V2,
ErrorWrapper,
ModelField,
RequiredParam,
Undefined,
_regenerate_error_with_loc,
_normalize_errors,
copy_field_info,
create_body_model,
evaluate_forwardref,
@ -46,6 +46,9 @@ from fastapi._compat import (
serialize_sequence_value,
value_is_sequence,
)
if TYPE_CHECKING: # pragma: nocover
from fastapi._compat import ErrorDetails
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
@ -563,7 +566,7 @@ async def solve_generator(
@dataclass
class SolvedDependency:
values: Dict[str, Any]
errors: List[Any]
errors: List["ErrorDetails"]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
@ -582,7 +585,7 @@ async def solve_dependencies(
embed_body_fields: bool,
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
errors: List[ErrorDetails] = []
if response is None:
response = Response()
del response.headers["content-length"]
@ -658,7 +661,8 @@ async def solve_dependencies(
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
for errors_ in (path_errors, query_errors, header_errors, cookie_errors):
errors.extend(errors_)
if dependant.body_params:
(
body_values,
@ -697,17 +701,15 @@ async def solve_dependencies(
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
) -> Tuple[Any, List["ErrorDetails"]]:
if value is None:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ErrorWrapper):
return None, [errors_]
elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
if errors_ is not None:
new_errors = _normalize_errors(errors_)
return None, new_errors
else:
return v_, []
@ -740,9 +742,9 @@ def _get_multidict_value(
def request_params_to_args(
fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
) -> Tuple[Dict[str, Any], List["ErrorDetails"]]:
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
errors: List[ErrorDetails] = []
if not fields:
return values, errors
@ -877,9 +879,9 @@ async def request_body_to_args(
body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
) -> Tuple[Dict[str, Any], List["ErrorDetails"]]:
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
errors: List[ErrorDetails] = []
assert body_fields, "request_body_to_args() should be called with fields"
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
first_field = body_fields[0]

12
fastapi/exceptions.py

@ -1,5 +1,7 @@
from typing import Any, Dict, Optional, Sequence, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union
if TYPE_CHECKING: # pragma: nocover
from fastapi._compat import ErrorDetails
from pydantic import BaseModel, create_model
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.exceptions import WebSocketException as StarletteWebSocketException
@ -147,15 +149,15 @@ class FastAPIError(RuntimeError):
class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
def __init__(self, errors: Sequence["ErrorDetails"]) -> None:
self._errors = errors
def errors(self) -> Sequence[Any]:
def errors(self) -> Sequence["ErrorDetails"]:
return self._errors
class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body
@ -165,7 +167,7 @@ class WebSocketRequestValidationError(ValidationException):
class ResponseValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body

18
fastapi/routing.py

@ -153,7 +153,6 @@ async def serialize_response(
is_coroutine: bool = True,
) -> Any:
if field:
errors = []
if not hasattr(field, "serialize"):
# pydantic v1
response_content = _prepare_response_content(
@ -163,15 +162,11 @@ async def serialize_response(
exclude_none=exclude_none,
)
if is_coroutine:
value, errors_ = field.validate(response_content, {}, loc=("response",))
value, errors = field.validate(response_content, {}, loc=("response",))
else:
value, errors_ = await run_in_threadpool(
value, errors = await run_in_threadpool(
field.validate, response_content, {}, loc=("response",)
)
if isinstance(errors_, list):
errors.extend(errors_)
elif errors_:
errors.append(errors_)
if errors:
raise ResponseValidationError(
errors=_normalize_errors(errors), body=response_content
@ -286,7 +281,6 @@ def get_request_handler(
status_code=400, detail="There was an error parsing the body"
)
raise http_error from e
errors: List[Any] = []
async with AsyncExitStack() as async_exit_stack:
solved_result = await solve_dependencies(
request=request,
@ -340,9 +334,7 @@ def get_request_handler(
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
)
validation_error = RequestValidationError(errors, body=body)
raise validation_error
if response is None:
raise FastAPIError(
@ -376,9 +368,7 @@ def get_websocket_app(
embed_body_fields=embed_body_fields,
)
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
)
raise WebSocketRequestValidationError(solved_result.errors)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**solved_result.values)

Loading…
Cancel
Save