|
|
@ -119,7 +119,9 @@ def get_param_sub_dependant( |
|
|
|
|
|
|
|
|
|
|
|
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: |
|
|
|
assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency" |
|
|
|
assert callable( |
|
|
|
depends.dependency |
|
|
|
), "A parameter-less dependency must have a callable dependency" |
|
|
|
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) |
|
|
|
|
|
|
|
|
|
|
@ -140,7 +142,9 @@ def get_sub_dependant( |
|
|
|
use_scopes: List[str] = [] |
|
|
|
if isinstance(dependency, (OAuth2, OpenIdConnect)): |
|
|
|
use_scopes = security_scopes |
|
|
|
security_requirement = SecurityRequirement(security_scheme=dependency, scopes=use_scopes) |
|
|
|
security_requirement = SecurityRequirement( |
|
|
|
security_scheme=dependency, scopes=use_scopes |
|
|
|
) |
|
|
|
sub_dependant = get_dependant( |
|
|
|
path=path, |
|
|
|
call=dependency, |
|
|
@ -179,7 +183,9 @@ def get_flat_dependant( |
|
|
|
for sub_dependant in dependant.dependencies: |
|
|
|
if skip_repeats and sub_dependant.cache_key in visited: |
|
|
|
continue |
|
|
|
flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) |
|
|
|
flat_sub = get_flat_dependant( |
|
|
|
sub_dependant, skip_repeats=skip_repeats, visited=visited |
|
|
|
) |
|
|
|
flat_dependant.path_params.extend(flat_sub.path_params) |
|
|
|
flat_dependant.query_params.extend(flat_sub.query_params) |
|
|
|
flat_dependant.header_params.extend(flat_sub.header_params) |
|
|
@ -209,10 +215,14 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: |
|
|
|
for param in fields: |
|
|
|
query_extra_info[param] = dict(fields[param].__repr_args__()) |
|
|
|
if "alias" in query_extra_info[param]: |
|
|
|
query_extra_info[query_extra_info[param]["alias"]] = dict(fields[param].__repr_args__()) |
|
|
|
query_extra_info[query_extra_info[param]["alias"]] = dict( |
|
|
|
fields[param].__repr_args__() |
|
|
|
) |
|
|
|
alias_dict[query_extra_info[param]["alias"]] = param |
|
|
|
query_extra_info[param]["default"] = ( |
|
|
|
Required if getattr(fields[param], "required", False) else fields[param].default |
|
|
|
Required |
|
|
|
if getattr(fields[param], "required", False) |
|
|
|
else fields[param].default |
|
|
|
) |
|
|
|
typed_params = [] |
|
|
|
|
|
|
@ -306,7 +316,9 @@ def get_dependant( |
|
|
|
type_annotation=type_annotation, |
|
|
|
dependant=dependant, |
|
|
|
): |
|
|
|
assert param_field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}" |
|
|
|
assert ( |
|
|
|
param_field is None |
|
|
|
), f"Cannot specify multiple FastAPI annotations for {param_name!r}" |
|
|
|
continue |
|
|
|
assert param_field is not None |
|
|
|
if is_body_param(param_field=param_field, is_path_param=is_path_param): |
|
|
@ -316,7 +328,9 @@ def get_dependant( |
|
|
|
return dependant |
|
|
|
|
|
|
|
|
|
|
|
def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]: |
|
|
|
def add_non_field_param_to_dependency( |
|
|
|
*, param_name: str, type_annotation: Any, dependant: Dependant |
|
|
|
) -> Optional[bool]: |
|
|
|
if lenient_issubclass(type_annotation, Request): |
|
|
|
dependant.request_param_name = param_name |
|
|
|
return True |
|
|
@ -348,17 +362,26 @@ def analyze_param( |
|
|
|
field_info = None |
|
|
|
depends = None |
|
|
|
type_annotation: Any = Any |
|
|
|
if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: |
|
|
|
if ( |
|
|
|
annotation is not inspect.Signature.empty |
|
|
|
and get_origin(annotation) is Annotated |
|
|
|
): |
|
|
|
annotated_args = get_args(annotation) |
|
|
|
type_annotation = annotated_args[0] |
|
|
|
fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))] |
|
|
|
fastapi_annotations = [ |
|
|
|
arg |
|
|
|
for arg in annotated_args[1:] |
|
|
|
if isinstance(arg, (FieldInfo, params.Depends)) |
|
|
|
] |
|
|
|
assert ( |
|
|
|
len(fastapi_annotations) <= 1 |
|
|
|
), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" |
|
|
|
fastapi_annotation = next(iter(fastapi_annotations), None) |
|
|
|
if isinstance(fastapi_annotation, FieldInfo): |
|
|
|
# Copy `field_info` because we mutate `field_info.default` below. |
|
|
|
field_info = copy_field_info(field_info=fastapi_annotation, annotation=annotation) |
|
|
|
field_info = copy_field_info( |
|
|
|
field_info=fastapi_annotation, annotation=annotation |
|
|
|
) |
|
|
|
assert field_info.default is Undefined or field_info.default is Required, ( |
|
|
|
f"`{field_info.__class__.__name__}` default value cannot be set in" |
|
|
|
f" `Annotated` for {param_name!r}. Set the default value with `=` instead." |
|
|
@ -375,7 +398,8 @@ def analyze_param( |
|
|
|
|
|
|
|
if isinstance(value, params.Depends): |
|
|
|
assert depends is None, ( |
|
|
|
"Cannot specify `Depends` in `Annotated` and default value" f" together for {param_name!r}" |
|
|
|
"Cannot specify `Depends` in `Annotated` and default value" |
|
|
|
f" together for {param_name!r}" |
|
|
|
) |
|
|
|
assert field_info is None, ( |
|
|
|
"Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" |
|
|
@ -384,7 +408,8 @@ def analyze_param( |
|
|
|
depends = value |
|
|
|
elif isinstance(value, FieldInfo): |
|
|
|
assert field_info is None, ( |
|
|
|
"Cannot specify FastAPI annotations in `Annotated` and default value" f" together for {param_name!r}" |
|
|
|
"Cannot specify FastAPI annotations in `Annotated` and default value" |
|
|
|
f" together for {param_name!r}" |
|
|
|
) |
|
|
|
field_info = value |
|
|
|
if PYDANTIC_V2: |
|
|
@ -405,7 +430,9 @@ def analyze_param( |
|
|
|
), |
|
|
|
): |
|
|
|
assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" |
|
|
|
assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}" |
|
|
|
assert ( |
|
|
|
field_info is None |
|
|
|
), f"Cannot specify FastAPI annotation for type {type_annotation!r}" |
|
|
|
elif field_info is None and depends is None: |
|
|
|
default_value = value if value is not inspect.Signature.empty else Required |
|
|
|
if is_path_param: |
|
|
@ -413,9 +440,9 @@ def analyze_param( |
|
|
|
# parameter might sometimes be a path parameter and sometimes not. See |
|
|
|
# `tests/test_infer_param_optionality.py` for an example. |
|
|
|
field_info = params.Path(annotation=type_annotation) |
|
|
|
elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation( |
|
|
|
elif is_uploadfile_or_nonable_uploadfile_annotation( |
|
|
|
type_annotation |
|
|
|
): |
|
|
|
) or is_uploadfile_sequence_annotation(type_annotation): |
|
|
|
field_info = params.File(annotation=type_annotation, default=default_value) |
|
|
|
elif not field_annotation_is_scalar(annotation=type_annotation): |
|
|
|
field_info = params.Body(annotation=type_annotation, default=default_value) |
|
|
@ -426,9 +453,13 @@ def analyze_param( |
|
|
|
if field_info is not None: |
|
|
|
if is_path_param: |
|
|
|
assert isinstance(field_info, params.Path), ( |
|
|
|
f"Cannot use `{field_info.__class__.__name__}` for path param" f" {param_name!r}" |
|
|
|
f"Cannot use `{field_info.__class__.__name__}` for path param" |
|
|
|
f" {param_name!r}" |
|
|
|
) |
|
|
|
elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None: |
|
|
|
elif ( |
|
|
|
isinstance(field_info, params.Param) |
|
|
|
and getattr(field_info, "in_", None) is None |
|
|
|
): |
|
|
|
field_info.in_ = params.ParamTypes.query |
|
|
|
use_annotation = get_annotation_from_field_info( |
|
|
|
type_annotation, |
|
|
@ -454,11 +485,15 @@ def analyze_param( |
|
|
|
|
|
|
|
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: |
|
|
|
if is_path_param: |
|
|
|
assert is_scalar_field(field=param_field), "Path params must be of one of the supported types" |
|
|
|
assert is_scalar_field( |
|
|
|
field=param_field |
|
|
|
), "Path params must be of one of the supported types" |
|
|
|
return False |
|
|
|
elif is_scalar_field(field=param_field): |
|
|
|
return False |
|
|
|
elif isinstance(param_field.field_info, (params.Query, params.Header)) and is_scalar_sequence_field(param_field): |
|
|
|
elif isinstance( |
|
|
|
param_field.field_info, (params.Query, params.Header) |
|
|
|
) and is_scalar_sequence_field(param_field): |
|
|
|
return False |
|
|
|
else: |
|
|
|
assert isinstance( |
|
|
@ -505,7 +540,9 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: |
|
|
|
return inspect.isgeneratorfunction(dunder_call) |
|
|
|
|
|
|
|
|
|
|
|
async def solve_generator(*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]) -> Any: |
|
|
|
async def solve_generator( |
|
|
|
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] |
|
|
|
) -> Any: |
|
|
|
if is_gen_callable(call): |
|
|
|
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) |
|
|
|
elif is_async_gen_callable(call): |
|
|
@ -539,12 +576,19 @@ async def solve_dependencies( |
|
|
|
sub_dependant: Dependant |
|
|
|
for sub_dependant in dependant.dependencies: |
|
|
|
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) |
|
|
|
sub_dependant.cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key) |
|
|
|
sub_dependant.cache_key = cast( |
|
|
|
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key |
|
|
|
) |
|
|
|
call = sub_dependant.call |
|
|
|
use_sub_dependant = sub_dependant |
|
|
|
if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides: |
|
|
|
if ( |
|
|
|
dependency_overrides_provider |
|
|
|
and dependency_overrides_provider.dependency_overrides |
|
|
|
): |
|
|
|
original_call = sub_dependant.call |
|
|
|
call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call) |
|
|
|
call = getattr( |
|
|
|
dependency_overrides_provider, "dependency_overrides", {} |
|
|
|
).get(original_call, original_call) |
|
|
|
use_path: str = sub_dependant.path # type: ignore |
|
|
|
use_sub_dependant = get_dependant( |
|
|
|
path=use_path, |
|
|
@ -578,7 +622,9 @@ async def solve_dependencies( |
|
|
|
elif is_gen_callable(call) or is_async_gen_callable(call): |
|
|
|
stack = request.scope.get("fastapi_astack") |
|
|
|
assert isinstance(stack, AsyncExitStack) |
|
|
|
solved = await solve_generator(call=call, stack=stack, sub_values=sub_values) |
|
|
|
solved = await solve_generator( |
|
|
|
call=call, stack=stack, sub_values=sub_values |
|
|
|
) |
|
|
|
elif is_coroutine_callable(call): |
|
|
|
solved = await call(**sub_values) |
|
|
|
else: |
|
|
@ -587,10 +633,18 @@ async def solve_dependencies( |
|
|
|
values[sub_dependant.name] = solved |
|
|
|
if sub_dependant.cache_key not in dependency_cache: |
|
|
|
dependency_cache[sub_dependant.cache_key] = solved |
|
|
|
path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params) |
|
|
|
query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params) |
|
|
|
header_values, header_errors = request_params_to_args(dependant.header_params, request.headers) |
|
|
|
cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies) |
|
|
|
path_values, path_errors = request_params_to_args( |
|
|
|
dependant.path_params, request.path_params |
|
|
|
) |
|
|
|
query_values, query_errors = request_params_to_args( |
|
|
|
dependant.query_params, request.query_params |
|
|
|
) |
|
|
|
header_values, header_errors = request_params_to_args( |
|
|
|
dependant.header_params, request.headers |
|
|
|
) |
|
|
|
cookie_values, cookie_errors = request_params_to_args( |
|
|
|
dependant.cookie_params, request.cookies |
|
|
|
) |
|
|
|
values.update(path_values) |
|
|
|
values.update(query_values) |
|
|
|
values.update(header_values) |
|
|
@ -618,7 +672,9 @@ async def solve_dependencies( |
|
|
|
if dependant.response_param_name: |
|
|
|
values[dependant.response_param_name] = response |
|
|
|
if dependant.security_scopes_param_name: |
|
|
|
values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.security_scopes) |
|
|
|
values[dependant.security_scopes_param_name] = SecurityScopes( |
|
|
|
scopes=dependant.security_scopes |
|
|
|
) |
|
|
|
return values, errors, background_tasks, response, dependency_cache |
|
|
|
|
|
|
|
|
|
|
@ -629,12 +685,16 @@ def request_params_to_args( |
|
|
|
values = {} |
|
|
|
errors = [] |
|
|
|
for field in required_params: |
|
|
|
if is_scalar_sequence_field(field) and isinstance(received_params, (QueryParams, Headers)): |
|
|
|
if is_scalar_sequence_field(field) and isinstance( |
|
|
|
received_params, (QueryParams, Headers) |
|
|
|
): |
|
|
|
value = received_params.getlist(field.alias) or field.default |
|
|
|
else: |
|
|
|
value = received_params.get(field.alias) |
|
|
|
field_info = field.field_info |
|
|
|
assert isinstance(field_info, params.Param), "Params must be subclasses of Param" |
|
|
|
assert isinstance( |
|
|
|
field_info, params.Param |
|
|
|
), "Params must be subclasses of Param" |
|
|
|
loc = (field_info.in_.value, field.alias) |
|
|
|
if value is None: |
|
|
|
if field.required: |
|
|
@ -687,21 +747,35 @@ async def request_body_to_args( |
|
|
|
if ( |
|
|
|
value is None |
|
|
|
or (isinstance(field_info, params.Form) and value == "") |
|
|
|
or (isinstance(field_info, params.Form) and is_sequence_field(field) and len(value) == 0) |
|
|
|
or ( |
|
|
|
isinstance(field_info, params.Form) |
|
|
|
and is_sequence_field(field) |
|
|
|
and len(value) == 0 |
|
|
|
) |
|
|
|
): |
|
|
|
if field.required: |
|
|
|
errors.append(get_missing_field_error(loc)) |
|
|
|
else: |
|
|
|
values[field.name] = deepcopy(field.default) |
|
|
|
continue |
|
|
|
if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile): |
|
|
|
if ( |
|
|
|
isinstance(field_info, params.File) |
|
|
|
and is_bytes_field(field) |
|
|
|
and isinstance(value, UploadFile) |
|
|
|
): |
|
|
|
value = await value.read() |
|
|
|
elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value): |
|
|
|
elif ( |
|
|
|
is_bytes_sequence_field(field) |
|
|
|
and isinstance(field_info, params.File) |
|
|
|
and value_is_sequence(value) |
|
|
|
): |
|
|
|
# For types |
|
|
|
assert isinstance(value, sequence_types) # type: ignore[arg-type] |
|
|
|
results: List[Union[bytes, str]] = [] |
|
|
|
|
|
|
|
async def process_fn(fn: Callable[[], Coroutine[Any, Any, Any]]) -> None: |
|
|
|
async def process_fn( |
|
|
|
fn: Callable[[], Coroutine[Any, Any, Any]] |
|
|
|
) -> None: |
|
|
|
result = await fn() |
|
|
|
results.append(result) # noqa: B023 |
|
|
|
|
|
|
@ -738,7 +812,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: |
|
|
|
for param in flat_dependant.body_params: |
|
|
|
setattr(param.field_info, "embed", True) # noqa: B010 |
|
|
|
model_name = "Body_" + name |
|
|
|
BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name) |
|
|
|
BodyModel = create_body_model( |
|
|
|
fields=flat_dependant.body_params, model_name=model_name |
|
|
|
) |
|
|
|
required = any(True for f in flat_dependant.body_params if f.required) |
|
|
|
BodyFieldInfo_kwargs: Dict[str, Any] = { |
|
|
|
"annotation": BodyModel, |
|
|
@ -754,7 +830,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: |
|
|
|
BodyFieldInfo = params.Body |
|
|
|
|
|
|
|
body_param_media_types = [ |
|
|
|
f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body) |
|
|
|
f.field_info.media_type |
|
|
|
for f in flat_dependant.body_params |
|
|
|
if isinstance(f.field_info, params.Body) |
|
|
|
] |
|
|
|
if len(set(body_param_media_types)) == 1: |
|
|
|
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] |
|
|
|