@ -4,6 +4,7 @@ from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
TYPE_CHECKING ,
Any ,
Callable ,
Deque ,
@ -23,7 +24,14 @@ from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel , create_model
from pydantic . version import VERSION as PYDANTIC_VERSION
from starlette . datastructures import UploadFile
from typing_extensions import Annotated , Literal , get_args , get_origin
from typing_extensions import (
Annotated ,
Literal ,
TypeAlias ,
assert_never ,
get_args ,
get_origin ,
)
PYDANTIC_VERSION_MINOR_TUPLE = tuple ( int ( x ) for x in PYDANTIC_VERSION . split ( " . " ) [ : 2 ] )
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE [ 0 ] == 2
@ -60,6 +68,7 @@ if PYDANTIC_V2:
from pydantic . json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic . json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import ErrorDetails as ErrorDetails
from pydantic_core import PydanticUndefined , PydanticUndefinedType
from pydantic_core import Url as Url
@ -84,6 +93,9 @@ if PYDANTIC_V2:
class ErrorWrapper ( Exception ) :
pass
# See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/error_wrappers.py#L45-L47.
ErrorList : TypeAlias = Union [ Sequence [ " ErrorList " ] , ErrorWrapper ]
@dataclass
class ModelField :
field_info : FieldInfo
@ -117,22 +129,25 @@ if PYDANTIC_V2:
return Undefined
return self . field_info . get_default ( call_default_factory = True )
# See https://github.com/pydantic/pydantic/blob/bb18ac5/pydantic/fields.py#L850-L852 for the signature.
def validate (
self ,
value : Any ,
values : Dict [ str , Any ] = { } , # noqa: B006
* ,
loc : Tuple [ Union [ int , str ] , . . . ] = ( ) ,
) - > Tuple [ Any , Union [ List [ Dict [ str , Any ] ] , None ] ] :
) - > Tuple [ Any , Union [ ErrorList , Sequence [ ErrorDetails ] , None ] ] :
try :
return (
self . _type_adapter . validate_python ( value , from_attributes = True ) ,
None ,
)
except ValidationError as exc :
return None , _regenerate_error_with_loc (
errors = exc . errors ( include_url = False ) , loc_prefix = loc
)
errors : List [ ErrorDetails ] = [
{ * * err , " loc " : loc + err [ " loc " ] }
for err in exc . errors ( include_url = False )
]
return None , errors
def serialize (
self ,
@ -169,7 +184,13 @@ if PYDANTIC_V2:
) - > Any :
return annotation
def _normalize_errors ( errors : Sequence [ Any ] ) - > List [ Dict [ str , Any ] ] :
def _normalize_errors (
errors : Union [ ErrorList , Sequence [ ErrorDetails ] ] ,
) - > List [ ErrorDetails ] :
assert isinstance ( errors , Sequence ) , type ( errors )
for error in errors :
assert not isinstance ( error , ErrorWrapper )
assert not isinstance ( error , Sequence )
return errors # type: ignore[return-value]
def _model_rebuild ( model : Type [ BaseModel ] ) - > None :
@ -267,12 +288,12 @@ if PYDANTIC_V2:
assert issubclass ( origin_type , sequence_types ) # type: ignore[arg-type]
return sequence_annotation_to_type [ origin_type ] ( value ) # type: ignore[no-any-return]
def get_missing_field_error ( loc : Tuple [ str , . . . ] ) - > Dict [ str , Any ] :
error = ValidationError . from_exception_data (
def get_missing_field_error ( loc : Tuple [ str , . . . ] ) - > ErrorDetails :
[ error ] = ValidationError . from_exception_data (
" Field required " , [ { " type " : " missing " , " loc " : loc , " input " : { } } ]
) . errors ( include_url = False ) [ 0 ]
) . errors ( include_url = False , include_input = False )
error [ " input " ] = None
return error # type: ignore[return-value]
return error
def create_body_model (
* , fields : Sequence [ ModelField ] , model_name : str
@ -297,6 +318,14 @@ else:
from pydantic . class_validators import ( # type: ignore[no-redef]
Validator as Validator ,
)
if TYPE_CHECKING : # pragma: nocover
from pydantic . error_wrappers import ( # type: ignore[no-redef]
ErrorDict as ErrorDetails ,
)
from pydantic . error_wrappers import ( # type: ignore[no-redef]
ErrorList as ErrorList ,
)
from pydantic . error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper ,
)
@ -427,18 +456,23 @@ else:
return True
return False
def _normalize_errors ( errors : Sequence [ Any ] ) - > List [ Dict [ str , Any ] ] :
use_errors : List [ Any ] = [ ]
for error in errors :
if isinstance ( error , ErrorWrapper ) :
new_errors = ValidationError ( # type: ignore[call-arg]
errors = [ error ] , model = RequestErrorModel
def _normalize_errors (
errors : Union [ ErrorList , Sequence [ " ErrorDetails " ] ] ,
) - > List [ " ErrorDetails " ] :
use_errors : List [ ErrorDetails ] = [ ]
if isinstance ( errors , ErrorWrapper ) :
use_errors . extend (
ValidationError ( # type: ignore[call-arg]
errors = [ errors ] , model = RequestErrorModel
) . errors ( )
use_errors . extend ( new_errors )
elif isinstance ( error , list ) :
)
elif isinstance ( errors , Sequence ) :
for error in errors :
assert not isinstance ( error , dict )
use_errors . extend ( _normalize_errors ( error ) )
else :
use_errors . append ( error )
return use_errors
else :
assert_never ( errors ) # pragma: no cover
return use_errors
def _model_rebuild ( model : Type [ BaseModel ] ) - > None :
@ -509,10 +543,10 @@ else:
def serialize_sequence_value ( * , field : ModelField , value : Any ) - > Sequence [ Any ] :
return sequence_shape_to_type [ field . shape ] ( value ) # type: ignore[no-any-return,attr-defined]
def get_missing_field_error ( loc : Tuple [ str , . . . ] ) - > Dict [ str , Any ] :
def get_missing_field_error ( loc : Tuple [ str , . . . ] ) - > " ErrorDetails " :
missing_field_error = ErrorWrapper ( MissingError ( ) , loc = loc ) # type: ignore[call-arg]
new_error = ValidationError ( [ missing_field_error ] , RequestErrorModel )
return new_error . errors ( ) [ 0 ] # type: ignore[return-value]
[ new_error ] = ValidationError ( [ missing_field_error ] , RequestErrorModel ) . errors ( )
return new_error
def create_body_model (
* , fields : Sequence [ ModelField ] , model_name : str
@ -526,17 +560,6 @@ else:
return list ( model . __fields__ . values ( ) ) # type: ignore[attr-defined]
def _regenerate_error_with_loc (
* , errors : Sequence [ Any ] , loc_prefix : Tuple [ Union [ str , int ] , . . . ]
) - > List [ Dict [ str , Any ] ] :
updated_loc_errors : List [ Any ] = [
{ * * err , " loc " : loc_prefix + err . get ( " loc " , ( ) ) }
for err in _normalize_errors ( errors )
]
return updated_loc_errors
def _annotation_is_sequence ( annotation : Union [ Type [ Any ] , None ] ) - > bool :
if lenient_issubclass ( annotation , ( str , bytes ) ) :
return False