diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 227ad837d..8457b245b 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, is_dataclass from enum import Enum from functools import lru_cache from typing import ( + TYPE_CHECKING, Any, Callable, Deque, @@ -24,7 +25,14 @@ from fastapi.types import IncEx, ModelNameMap, UnionType from pydantic import BaseModel, create_model from pydantic.version import VERSION as PYDANTIC_VERSION from starlette.datastructures import UploadFile -from typing_extensions import Annotated, Literal, get_args, get_origin +from typing_extensions import ( + Annotated, + Literal, + TypeAlias, + assert_never, + get_args, + get_origin, +) PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 @@ -61,6 +69,7 @@ if PYDANTIC_V2: from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue from pydantic_core import CoreSchema as CoreSchema + from pydantic_core import ErrorDetails as ErrorDetails from pydantic_core import PydanticUndefined, PydanticUndefinedType from pydantic_core import Url as Url @@ -70,7 +79,7 @@ if PYDANTIC_V2: ) except ImportError: # pragma: no cover from pydantic_core.core_schema import ( - general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 + general_plain_validator_function as with_info_plain_validator_function, ) RequiredParam = PydanticUndefined @@ -85,6 +94,9 @@ if PYDANTIC_V2: class ErrorWrapper(Exception): pass + # See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/error_wrappers.py#L45-L47. + ErrorList: TypeAlias = Union[Sequence["ErrorList"], ErrorWrapper] + @dataclass class ModelField: field_info: FieldInfo @@ -118,22 +130,25 @@ if PYDANTIC_V2: return Undefined return self.field_info.get_default(call_default_factory=True) + # See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/fields.py#L850-L852 for the signature. def validate( self, value: Any, values: Dict[str, Any] = {}, # noqa: B006 *, loc: Tuple[Union[int, str], ...] = (), - ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + ) -> Tuple[Any, Union[ErrorList, Sequence[ErrorDetails], None]]: try: return ( self._type_adapter.validate_python(value, from_attributes=True), None, ) except ValidationError as exc: - return None, _regenerate_error_with_loc( - errors=exc.errors(include_url=False), loc_prefix=loc - ) + errors: List[ErrorDetails] = [ + {**err, "loc": loc + err["loc"]} + for err in exc.errors(include_url=False) + ] + return None, errors def serialize( self, @@ -170,7 +185,13 @@ if PYDANTIC_V2: ) -> Any: return annotation - def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: + def _normalize_errors( + errors: Union[ErrorList, Sequence[ErrorDetails]], + ) -> List[ErrorDetails]: + assert isinstance(errors, Sequence), type(errors) + for error in errors: + assert not isinstance(error, ErrorWrapper) + assert not isinstance(error, Sequence) return errors # type: ignore[return-value] def _model_rebuild(model: Type[BaseModel]) -> None: @@ -272,12 +293,12 @@ if PYDANTIC_V2: assert issubclass(origin_type, sequence_types) # type: ignore[arg-type] return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] - def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: - error = ValidationError.from_exception_data( + def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails: + [error] = ValidationError.from_exception_data( "Field required", [{"type": "missing", "loc": loc, "input": {}}] - ).errors(include_url=False)[0] + ).errors(include_url=False, include_input=False) error["input"] = None - return error # type: ignore[return-value] + return error def create_body_model( *, fields: Sequence[ModelField], model_name: str @@ -296,14 +317,22 @@ else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 from pydantic import ( # type: ignore[assignment] - BaseConfig as BaseConfig, # noqa: F401 + BaseConfig as BaseConfig, ) - from pydantic import ValidationError as ValidationError # noqa: F401 + from pydantic import ValidationError as ValidationError from pydantic.class_validators import ( # type: ignore[no-redef] - Validator as Validator, # noqa: F401 + Validator as Validator, + ) + + if TYPE_CHECKING: # pragma: nocover + from pydantic.error_wrappers import ( # type: ignore[no-redef] + ErrorDict as ErrorDetails, + ) + from pydantic.error_wrappers import ( # type: ignore[no-redef] + ErrorList as ErrorList, ) from pydantic.error_wrappers import ( # type: ignore[no-redef] - ErrorWrapper as ErrorWrapper, # noqa: F401 + ErrorWrapper as ErrorWrapper, ) from pydantic.errors import MissingError from pydantic.fields import ( # type: ignore[attr-defined] @@ -317,7 +346,7 @@ else: ) from pydantic.fields import FieldInfo as FieldInfo from pydantic.fields import ( # type: ignore[no-redef,attr-defined] - ModelField as ModelField, # noqa: F401 + ModelField as ModelField, ) # Keeping old "Required" functionality from Pydantic V1, without @@ -327,7 +356,7 @@ else: Undefined as Undefined, ) from pydantic.fields import ( # type: ignore[no-redef, attr-defined] - UndefinedType as UndefinedType, # noqa: F401 + UndefinedType as UndefinedType, ) from pydantic.schema import ( field_schema, @@ -335,14 +364,14 @@ else: get_model_name_map, model_process_schema, ) - from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401 + from pydantic.schema import ( # type: ignore[no-redef] get_annotation_from_field_info as get_annotation_from_field_info, ) from pydantic.typing import ( # type: ignore[no-redef] - evaluate_forwardref as evaluate_forwardref, # noqa: F401 + evaluate_forwardref as evaluate_forwardref, ) from pydantic.utils import ( # type: ignore[no-redef] - lenient_issubclass as lenient_issubclass, # noqa: F401 + lenient_issubclass as lenient_issubclass, ) GetJsonSchemaHandler = Any # type: ignore[assignment,misc] @@ -432,18 +461,23 @@ else: return True return False - def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: - use_errors: List[Any] = [] - for error in errors: - if isinstance(error, ErrorWrapper): - new_errors = ValidationError( # type: ignore[call-arg] - errors=[error], model=RequestErrorModel + def _normalize_errors( + errors: Union[ErrorList, Sequence["ErrorDetails"]], + ) -> List["ErrorDetails"]: + use_errors: List[ErrorDetails] = [] + if isinstance(errors, ErrorWrapper): + use_errors.extend( + ValidationError( # type: ignore[call-arg] + errors=[errors], model=RequestErrorModel ).errors() - use_errors.extend(new_errors) - elif isinstance(error, list): + ) + elif isinstance(errors, Sequence): + for error in errors: + assert not isinstance(error, dict) use_errors.extend(_normalize_errors(error)) - else: - use_errors.append(error) + return use_errors + else: + assert_never(errors) # pragma: no cover return use_errors def _model_rebuild(model: Type[BaseModel]) -> None: @@ -514,10 +548,10 @@ else: def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined] - def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + def get_missing_field_error(loc: Tuple[str, ...]) -> "ErrorDetails": missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg] - new_error = ValidationError([missing_field_error], RequestErrorModel) - return new_error.errors()[0] # type: ignore[return-value] + [new_error] = ValidationError([missing_field_error], RequestErrorModel).errors() + return new_error def create_body_model( *, fields: Sequence[ModelField], model_name: str @@ -531,17 +565,6 @@ else: return list(model.__fields__.values()) # type: ignore[attr-defined] -def _regenerate_error_with_loc( - *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] -) -> List[Dict[str, Any]]: - updated_loc_errors: List[Any] = [ - {**err, "loc": loc_prefix + err.get("loc", ())} - for err in _normalize_errors(errors) - ] - - return updated_loc_errors - - def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: if lenient_issubclass(annotation, (str, bytes)): return False diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..5e26c8dd8 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -3,6 +3,7 @@ from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, Callable, Coroutine, @@ -22,11 +23,10 @@ import anyio from fastapi import params from fastapi._compat import ( PYDANTIC_V2, - ErrorWrapper, ModelField, RequiredParam, Undefined, - _regenerate_error_with_loc, + _normalize_errors, copy_field_info, create_body_model, evaluate_forwardref, @@ -46,6 +46,9 @@ from fastapi._compat import ( serialize_sequence_value, value_is_sequence, ) + +if TYPE_CHECKING: # pragma: nocover + from fastapi._compat import ErrorDetails from fastapi.background import BackgroundTasks from fastapi.concurrency import ( asynccontextmanager, @@ -563,7 +566,7 @@ async def solve_generator( @dataclass class SolvedDependency: values: Dict[str, Any] - errors: List[Any] + errors: List["ErrorDetails"] background_tasks: Optional[StarletteBackgroundTasks] response: Response dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] @@ -582,7 +585,7 @@ async def solve_dependencies( embed_body_fields: bool, ) -> SolvedDependency: values: Dict[str, Any] = {} - errors: List[Any] = [] + errors: List[ErrorDetails] = [] if response is None: response = Response() del response.headers["content-length"] @@ -658,7 +661,8 @@ async def solve_dependencies( values.update(query_values) values.update(header_values) values.update(cookie_values) - errors += path_errors + query_errors + header_errors + cookie_errors + for errors_ in (path_errors, query_errors, header_errors, cookie_errors): + errors.extend(errors_) if dependant.body_params: ( body_values, @@ -697,17 +701,15 @@ async def solve_dependencies( def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] -) -> Tuple[Any, List[Any]]: +) -> Tuple[Any, List["ErrorDetails"]]: if value is None: if field.required: return None, [get_missing_field_error(loc=loc)] else: return deepcopy(field.default), [] v_, errors_ = field.validate(value, values, loc=loc) - if isinstance(errors_, ErrorWrapper): - return None, [errors_] - elif isinstance(errors_, list): - new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + if errors_ is not None: + new_errors = _normalize_errors(errors_) return None, new_errors else: return v_, [] @@ -740,9 +742,9 @@ def _get_multidict_value( def request_params_to_args( fields: Sequence[ModelField], received_params: Union[Mapping[str, Any], QueryParams, Headers], -) -> Tuple[Dict[str, Any], List[Any]]: +) -> Tuple[Dict[str, Any], List["ErrorDetails"]]: values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] + errors: List[ErrorDetails] = [] if not fields: return values, errors @@ -906,9 +908,9 @@ async def request_body_to_args( body_fields: List[ModelField], received_body: Optional[Union[Dict[str, Any], FormData]], embed_body_fields: bool, -) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: +) -> Tuple[Dict[str, Any], List["ErrorDetails"]]: values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] + errors: List[ErrorDetails] = [] assert body_fields, "request_body_to_args() should be called with fields" single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields first_field = body_fields[0] diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 44d4ada86..ca41bdac1 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -1,5 +1,7 @@ -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union +if TYPE_CHECKING: # pragma: nocover + from fastapi._compat import ErrorDetails from pydantic import BaseModel, create_model from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import WebSocketException as StarletteWebSocketException @@ -147,15 +149,15 @@ class FastAPIError(RuntimeError): class ValidationException(Exception): - def __init__(self, errors: Sequence[Any]) -> None: + def __init__(self, errors: Sequence["ErrorDetails"]) -> None: self._errors = errors - def errors(self) -> Sequence[Any]: + def errors(self) -> Sequence["ErrorDetails"]: return self._errors class RequestValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None: super().__init__(errors) self.body = body @@ -165,7 +167,7 @@ class WebSocketRequestValidationError(ValidationException): class ResponseValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence["ErrorDetails"], *, body: Any = None) -> None: super().__init__(errors) self.body = body diff --git a/fastapi/routing.py b/fastapi/routing.py index 54c75a027..bb7c0bbc7 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -154,7 +154,6 @@ async def serialize_response( is_coroutine: bool = True, ) -> Any: if field: - errors = [] if not hasattr(field, "serialize"): # pydantic v1 response_content = _prepare_response_content( @@ -164,15 +163,11 @@ async def serialize_response( exclude_none=exclude_none, ) if is_coroutine: - value, errors_ = field.validate(response_content, {}, loc=("response",)) + value, errors = field.validate(response_content, {}, loc=("response",)) else: - value, errors_ = await run_in_threadpool( + value, errors = await run_in_threadpool( field.validate, response_content, {}, loc=("response",) ) - if isinstance(errors_, list): - errors.extend(errors_) - elif errors_: - errors.append(errors_) if errors: raise ResponseValidationError( errors=_normalize_errors(errors), body=response_content @@ -287,7 +282,6 @@ def get_request_handler( status_code=400, detail="There was an error parsing the body" ) raise http_error from e - errors: List[Any] = [] async with AsyncExitStack() as async_exit_stack: solved_result = await solve_dependencies( request=request, @@ -341,9 +335,7 @@ def get_request_handler( response.body = b"" response.headers.raw.extend(solved_result.response.headers.raw) if errors: - validation_error = RequestValidationError( - _normalize_errors(errors), body=body - ) + validation_error = RequestValidationError(errors, body=body) raise validation_error if response is None: raise FastAPIError( @@ -377,9 +369,7 @@ def get_websocket_app( embed_body_fields=embed_body_fields, ) if solved_result.errors: - raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) - ) + raise WebSocketRequestValidationError(solved_result.errors) assert dependant.call is not None, "dependant.call must be a function" await dependant.call(**solved_result.values) diff --git a/tests/test_filter_pydantic_sub_model/app_pv1.py b/tests/test_filter_pydantic_sub_model/app_pv1.py index 657e8c5d1..ab3bf87a8 100644 --- a/tests/test_filter_pydantic_sub_model/app_pv1.py +++ b/tests/test_filter_pydantic_sub_model/app_pv1.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional from fastapi import Depends, FastAPI from pydantic import BaseModel, validator @@ -20,7 +20,7 @@ class ModelA(BaseModel): model_b: ModelB @validator("name") - def lower_username(cls, name: str, values): + def lower_username(cls, name: str, values: Dict[str, Any]) -> str: if not name.endswith("A"): raise ValueError("name must end in A") return name @@ -31,5 +31,7 @@ async def get_model_c() -> ModelC: @app.get("/model/{name}", response_model=ModelA) -async def get_model_a(name: str, model_c=Depends(get_model_c)): +async def get_model_a( + name: str, model_c: ModelC = Depends(get_model_c) +) -> Dict[str, Any]: return {"name": name, "description": "model-a-desc", "model_b": model_c} diff --git a/tests/test_filter_pydantic_sub_model/test_filter_pydantic_sub_model_pv1.py b/tests/test_filter_pydantic_sub_model/test_filter_pydantic_sub_model_pv1.py index 48732dbf0..7d7ef1780 100644 --- a/tests/test_filter_pydantic_sub_model/test_filter_pydantic_sub_model_pv1.py +++ b/tests/test_filter_pydantic_sub_model/test_filter_pydantic_sub_model_pv1.py @@ -1,4 +1,9 @@ +from typing import TYPE_CHECKING, Sequence + import pytest + +if TYPE_CHECKING: # pragma: nocover + from fastapi._compat import ErrorDetails from fastapi.exceptions import ResponseValidationError from fastapi.testclient import TestClient @@ -6,7 +11,7 @@ from ..utils import needs_pydanticv1 @pytest.fixture(name="client") -def get_client(): +def get_client() -> TestClient: from .app_pv1 import app client = TestClient(app) @@ -14,7 +19,7 @@ def get_client(): @needs_pydanticv1 -def test_filter_sub_model(client: TestClient): +def test_filter_sub_model(client: TestClient) -> None: response = client.get("/model/modelA") assert response.status_code == 200, response.text assert response.json() == { @@ -25,10 +30,11 @@ def test_filter_sub_model(client: TestClient): @needs_pydanticv1 -def test_validator_is_cloned(client: TestClient): +def test_validator_is_cloned(client: TestClient) -> None: with pytest.raises(ResponseValidationError) as err: client.get("/model/modelX") - assert err.value.errors() == [ + errors: Sequence[ErrorDetails] = err.value.errors() + assert errors == [ { "loc": ("response", "name"), "msg": "name must end in A", @@ -38,7 +44,7 @@ def test_validator_is_cloned(client: TestClient): @needs_pydanticv1 -def test_openapi_schema(client: TestClient): +def test_openapi_schema(client: TestClient) -> None: response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { diff --git a/tests/test_filter_pydantic_sub_model_pv2.py b/tests/test_filter_pydantic_sub_model_pv2.py index 2e2c26ddc..eb4cdd59b 100644 --- a/tests/test_filter_pydantic_sub_model_pv2.py +++ b/tests/test_filter_pydantic_sub_model_pv2.py @@ -1,8 +1,11 @@ -from typing import Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence import pytest -from dirty_equals import HasRepr, IsDict, IsOneOf +from dirty_equals import HasRepr from fastapi import Depends, FastAPI + +if TYPE_CHECKING: # pragma: nocover + from fastapi._compat import ErrorDetails from fastapi.exceptions import ResponseValidationError from fastapi.testclient import TestClient @@ -10,7 +13,7 @@ from .utils import needs_pydanticv2 @pytest.fixture(name="client") -def get_client(): +def get_client() -> TestClient: from pydantic import BaseModel, ValidationInfo, field_validator app = FastAPI() @@ -27,7 +30,7 @@ def get_client(): foo: ModelB @field_validator("name") - def lower_username(cls, name: str, info: ValidationInfo): + def lower_username(cls, name: str, info: ValidationInfo) -> str: if not name.endswith("A"): raise ValueError("name must end in A") return name @@ -36,7 +39,9 @@ def get_client(): return ModelC(username="test-user", password="test-password") @app.get("/model/{name}", response_model=ModelA) - async def get_model_a(name: str, model_c=Depends(get_model_c)): + async def get_model_a( + name: str, model_c: ModelC = Depends(get_model_c) + ) -> Dict[str, Any]: return {"name": name, "description": "model-a-desc", "foo": model_c} client = TestClient(app) @@ -44,7 +49,7 @@ def get_client(): @needs_pydanticv2 -def test_filter_sub_model(client: TestClient): +def test_filter_sub_model(client: TestClient) -> None: response = client.get("/model/modelA") assert response.status_code == 200, response.text assert response.json() == { @@ -55,32 +60,23 @@ def test_filter_sub_model(client: TestClient): @needs_pydanticv2 -def test_validator_is_cloned(client: TestClient): +def test_validator_is_cloned(client: TestClient) -> None: with pytest.raises(ResponseValidationError) as err: client.get("/model/modelX") - assert err.value.errors() == [ - IsDict( - { - "type": "value_error", - "loc": ("response", "name"), - "msg": "Value error, name must end in A", - "input": "modelX", - "ctx": {"error": HasRepr("ValueError('name must end in A')")}, - } - ) - | IsDict( - # TODO remove when deprecating Pydantic v1 - { - "loc": ("response", "name"), - "msg": "name must end in A", - "type": "value_error", - } - ) + errors: Sequence[ErrorDetails] = err.value.errors() + assert errors == [ + { + "type": "value_error", + "loc": ("response", "name"), + "msg": "Value error, name must end in A", + "input": "modelX", + "ctx": {"error": HasRepr("ValueError('name must end in A')")}, + } ] @needs_pydanticv2 -def test_openapi_schema(client: TestClient): +def test_openapi_schema(client: TestClient) -> None: response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == { @@ -137,23 +133,14 @@ def test_openapi_schema(client: TestClient): }, "ModelA": { "title": "ModelA", - "required": IsOneOf( - ["name", "description", "foo"], - # TODO remove when deprecating Pydantic v1 - ["name", "foo"], - ), + "required": ["name", "foo"], "type": "object", "properties": { "name": {"title": "Name", "type": "string"}, - "description": IsDict( - { - "title": "Description", - "anyOf": [{"type": "string"}, {"type": "null"}], - } - ) - | - # TODO remove when deprecating Pydantic v1 - IsDict({"title": "Description", "type": "string"}), + "description": { + "title": "Description", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, "foo": {"$ref": "#/components/schemas/ModelB"}, }, },