@ -59,7 +59,13 @@ from fastapi.utils import create_model_field, get_path_param_names
from pydantic . fields import FieldInfo
from starlette . background import BackgroundTasks as StarletteBackgroundTasks
from starlette . concurrency import run_in_threadpool
from starlette . datastructures import FormData , Headers , QueryParams , UploadFile
from starlette . datastructures import (
FormData ,
Headers ,
ImmutableMultiDict ,
QueryParams ,
UploadFile ,
)
from starlette . requests import HTTPConnection , Request
from starlette . responses import Response
from starlette . websockets import WebSocket
@ -282,7 +288,7 @@ def get_dependant(
) , f " Cannot specify multiple FastAPI annotations for { param_name !r} "
continue
assert param_details . field is not None
if is_body_param ( param_field = param_details . field , is_path_param = is_path_param ) :
if isinstance ( param_details . field . field_info , params . Body ) :
dependant . body_params . append ( param_details . field )
else :
add_param_to_fields ( field = param_details . field , dependant = dependant )
@ -466,29 +472,16 @@ def analyze_param(
required = field_info . default in ( Required , Undefined ) ,
field_info = field_info ,
)
if is_path_param :
assert is_scalar_field (
field = field
) , " Path params must be of one of the supported types "
elif isinstance ( field_info , params . Query ) :
assert is_scalar_field ( field ) or is_scalar_sequence_field ( field )
return ParamDetails ( type_annotation = type_annotation , depends = depends , field = field )
def is_body_param ( * , param_field : ModelField , is_path_param : bool ) - > bool :
if is_path_param :
assert is_scalar_field (
field = param_field
) , " Path params must be of one of the supported types "
return False
elif is_scalar_field ( field = param_field ) :
return False
elif isinstance (
param_field . field_info , ( params . Query , params . Header )
) and is_scalar_sequence_field ( param_field ) :
return False
else :
assert isinstance (
param_field . field_info , params . Body
) , f " Param: { param_field . name } can only be a request body, using Body() "
return True
def add_param_to_fields ( * , field : ModelField , dependant : Dependant ) - > None :
field_info = field . field_info
field_info_in = getattr ( field_info , " in_ " , None )
@ -557,6 +550,7 @@ async def solve_dependencies(
dependency_overrides_provider : Optional [ Any ] = None ,
dependency_cache : Optional [ Dict [ Tuple [ Callable [ . . . , Any ] , Tuple [ str ] ] , Any ] ] = None ,
async_exit_stack : AsyncExitStack ,
embed_body_fields : bool ,
) - > SolvedDependency :
values : Dict [ str , Any ] = { }
errors : List [ Any ] = [ ]
@ -598,6 +592,7 @@ async def solve_dependencies(
dependency_overrides_provider = dependency_overrides_provider ,
dependency_cache = dependency_cache ,
async_exit_stack = async_exit_stack ,
embed_body_fields = embed_body_fields ,
)
background_tasks = solved_result . background_tasks
dependency_cache . update ( solved_result . dependency_cache )
@ -640,7 +635,9 @@ async def solve_dependencies(
body_values ,
body_errors ,
) = await request_body_to_args ( # body_params checked above
required_params = dependant . body_params , received_body = body
body_fields = dependant . body_params ,
received_body = body ,
embed_body_fields = embed_body_fields ,
)
values . update ( body_values )
errors . extend ( body_errors )
@ -669,138 +666,185 @@ async def solve_dependencies(
)
def _validate_value_with_model_field (
* , field : ModelField , value : Any , values : Dict [ str , Any ] , loc : Tuple [ str , . . . ]
) - > Tuple [ Any , List [ Any ] ] :
if value is None :
if field . required :
return None , [ get_missing_field_error ( loc = loc ) ]
else :
return deepcopy ( field . default ) , [ ]
v_ , errors_ = field . validate ( value , values , loc = loc )
if isinstance ( errors_ , ErrorWrapper ) :
return None , [ errors_ ]
elif isinstance ( errors_ , list ) :
new_errors = _regenerate_error_with_loc ( errors = errors_ , loc_prefix = ( ) )
return None , new_errors
else :
return v_ , [ ]
def _get_multidict_value ( field : ModelField , values : Mapping [ str , Any ] ) - > Any :
if is_sequence_field ( field ) and isinstance ( values , ( ImmutableMultiDict , Headers ) ) :
value = values . getlist ( field . alias )
else :
value = values . get ( field . alias , None )
if (
value is None
or (
isinstance ( field . field_info , params . Form )
and isinstance ( value , str ) # For type checks
and value == " "
)
or ( is_sequence_field ( field ) and len ( value ) == 0 )
) :
if field . required :
return
else :
return deepcopy ( field . default )
return value
def request_params_to_args (
required_params : Sequence [ ModelField ] ,
field s: Sequence [ ModelField ] ,
received_params : Union [ Mapping [ str , Any ] , QueryParams , Headers ] ,
) - > Tuple [ Dict [ str , Any ] , List [ Any ] ] :
values = { }
values : Dict [ str , Any ] = { }
errors = [ ]
for field in required_params :
if is_scalar_sequence_field ( field ) and isinstance (
received_params , ( QueryParams , Headers )
) :
value = received_params . getlist ( field . alias ) or field . default
else :
value = received_params . get ( field . alias )
for field in fields :
value = _get_multidict_value ( field , received_params )
field_info = field . field_info
assert isinstance (
field_info , params . Param
) , " Params must be subclasses of Param "
loc = ( field_info . in_ . value , field . alias )
if value is None :
if field . required :
errors . append ( get_missing_field_error ( loc = loc ) )
else :
values [ field . name ] = deepcopy ( field . default )
continue
v_ , errors_ = field . validate ( value , values , loc = loc )
if isinstance ( errors_ , ErrorWrapper ) :
errors . append ( errors_ )
elif isinstance ( errors_ , list ) :
new_errors = _regenerate_error_with_loc ( errors = errors_ , loc_prefix = ( ) )
errors . extend ( new_errors )
v_ , errors_ = _validate_value_with_model_field (
field = field , value = value , values = values , loc = loc
)
if errors_ :
errors . extend ( errors_ )
else :
values [ field . name ] = v_
return values , errors
def _should_embed_body_fields ( fields : List [ ModelField ] ) - > bool :
if not fields :
return False
# More than one dependency could have the same field, it would show up as multiple
# fields but it's the same one, so count them by name
body_param_names_set = { field . name for field in fields }
# A top level field has to be a single field, not multiple
if len ( body_param_names_set ) > 1 :
return True
first_field = fields [ 0 ]
# If it explicitly specifies it is embedded, it has to be embedded
if getattr ( first_field . field_info , " embed " , None ) :
return True
# If it's a Form (or File) field, it has to be a BaseModel to be top level
# otherwise it has to be embedded, so that the key value pair can be extracted
if isinstance ( first_field . field_info , params . Form ) :
return True
return False
async def _extract_form_body (
body_fields : List [ ModelField ] ,
received_body : FormData ,
) - > Dict [ str , Any ] :
values = { }
first_field = body_fields [ 0 ]
first_field_info = first_field . field_info
for field in body_fields :
value = _get_multidict_value ( field , received_body )
if (
isinstance ( first_field_info , params . File )
and is_bytes_field ( field )
and isinstance ( value , UploadFile )
) :
value = await value . read ( )
elif (
is_bytes_sequence_field ( field )
and isinstance ( first_field_info , params . File )
and value_is_sequence ( value )
) :
# For types
assert isinstance ( value , sequence_types ) # type: ignore[arg-type]
results : List [ Union [ bytes , str ] ] = [ ]
async def process_fn (
fn : Callable [ [ ] , Coroutine [ Any , Any , Any ] ] ,
) - > None :
result = await fn ( )
results . append ( result ) # noqa: B023
async with anyio . create_task_group ( ) as tg :
for sub_value in value :
tg . start_soon ( process_fn , sub_value . read )
value = serialize_sequence_value ( field = field , value = results )
values [ field . name ] = value
return values
async def request_body_to_args (
required_params : List [ ModelField ] ,
body_field s: List [ ModelField ] ,
received_body : Optional [ Union [ Dict [ str , Any ] , FormData ] ] ,
embed_body_fields : bool ,
) - > Tuple [ Dict [ str , Any ] , List [ Dict [ str , Any ] ] ] :
values = { }
values : Dict [ str , Any ] = { }
errors : List [ Dict [ str , Any ] ] = [ ]
if required_params :
field = required_params [ 0 ]
field_info = field . field_info
embed = getattr ( field_info , " embed " , None )
field_alias_omitted = len ( required_params ) == 1 and not embed
if field_alias_omitted :
received_body = { field . alias : received_body }
for field in required_params :
loc : Tuple [ str , . . . ]
if field_alias_omitted :
loc = ( " body " , )
else :
loc = ( " body " , field . alias )
value : Optional [ Any ] = None
if received_body is not None :
if ( is_sequence_field ( field ) ) and isinstance ( received_body , FormData ) :
value = received_body . getlist ( field . alias )
else :
try :
value = received_body . get ( field . alias )
except AttributeError :
errors . append ( get_missing_field_error ( loc ) )
continue
if (
value is None
or ( isinstance ( field_info , params . Form ) and value == " " )
or (
isinstance ( field_info , params . Form )
and is_sequence_field ( field )
and len ( value ) == 0
)
) :
if field . required :
errors . append ( get_missing_field_error ( loc ) )
else :
values [ field . name ] = deepcopy ( field . default )
assert body_fields , " request_body_to_args() should be called with fields "
single_not_embedded_field = len ( body_fields ) == 1 and not embed_body_fields
first_field = body_fields [ 0 ]
body_to_process = received_body
if isinstance ( received_body , FormData ) :
body_to_process = await _extract_form_body ( body_fields , received_body )
if single_not_embedded_field :
loc : Tuple [ str , . . . ] = ( " body " , )
v_ , errors_ = _validate_value_with_model_field (
field = first_field , value = body_to_process , values = values , loc = loc
)
return { first_field . name : v_ } , errors_
for field in body_fields :
loc = ( " body " , field . alias )
value : Optional [ Any ] = None
if body_to_process is not None :
try :
value = body_to_process . get ( field . alias )
# If the received body is a list, not a dict
except AttributeError :
errors . append ( get_missing_field_error ( loc ) )
continue
if (
isinstance ( field_info , params . File )
and is_bytes_field ( field )
and isinstance ( value , UploadFile )
) :
value = await value . read ( )
elif (
is_bytes_sequence_field ( field )
and isinstance ( field_info , params . File )
and value_is_sequence ( value )
) :
# For types
assert isinstance ( value , sequence_types ) # type: ignore[arg-type]
results : List [ Union [ bytes , str ] ] = [ ]
async def process_fn (
fn : Callable [ [ ] , Coroutine [ Any , Any , Any ] ] ,
) - > None :
result = await fn ( )
results . append ( result ) # noqa: B023
async with anyio . create_task_group ( ) as tg :
for sub_value in value :
tg . start_soon ( process_fn , sub_value . read )
value = serialize_sequence_value ( field = field , value = results )
v_ , errors_ = field . validate ( value , values , loc = loc )
if isinstance ( errors_ , list ) :
errors . extend ( errors_ )
elif errors_ :
errors . append ( errors_ )
else :
values [ field . name ] = v_
v_ , errors_ = _validate_value_with_model_field (
field = field , value = value , values = values , loc = loc
)
if errors_ :
errors . extend ( errors_ )
else :
values [ field . name ] = v_
return values , errors
def get_body_field ( * , dependant : Dependant , name : str ) - > Optional [ ModelField ] :
flat_dependant = get_flat_dependant ( dependant )
def get_body_field (
* , flat_dependant : Dependant , name : str , embed_body_fields : bool
) - > Optional [ ModelField ] :
"""
Get a ModelField representing the request body for a path operation , combining
all body parameters into a single field if necessary .
Used to check if it ' s form data (with `isinstance(body_field, params.Form)`)
or JSON and to generate the JSON Schema for a request body .
This is * * not * * used to validate / parse the request body , that ' s done with each
individual body parameter .
"""
if not flat_dependant . body_params :
return None
first_param = flat_dependant . body_params [ 0 ]
field_info = first_param . field_info
embed = getattr ( field_info , " embed " , None )
body_param_names_set = { param . name for param in flat_dependant . body_params }
if len ( body_param_names_set ) == 1 and not embed :
if not embed_body_fields :
return first_param
# If one field requires to embed, all have to be embedded
# in case a sub-dependency is evaluated with a single unique body field
# That is combined (embedded) with other body fields
for param in flat_dependant . body_params :
setattr ( param . field_info , " embed " , True ) # noqa: B010
model_name = " Body_ " + name
BodyModel = create_body_model (
fields = flat_dependant . body_params , model_name = model_name