Browse Source

🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

pull/4573/head
pre-commit-ci[bot] 1 year ago
parent
commit
21532f7928
  1. 152
      fastapi/dependencies/utils.py

152
fastapi/dependencies/utils.py

@ -119,7 +119,9 @@ def get_param_sub_dependant(
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency" assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
@ -140,7 +142,9 @@ def get_sub_dependant(
use_scopes: List[str] = [] use_scopes: List[str] = []
if isinstance(dependency, (OAuth2, OpenIdConnect)): if isinstance(dependency, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes use_scopes = security_scopes
security_requirement = SecurityRequirement(security_scheme=dependency, scopes=use_scopes) security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_dependant( sub_dependant = get_dependant(
path=path, path=path,
call=dependency, call=dependency,
@ -179,7 +183,9 @@ def get_flat_dependant(
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited: if skip_repeats and sub_dependant.cache_key in visited:
continue continue
flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) flat_sub = get_flat_dependant(
sub_dependant, skip_repeats=skip_repeats, visited=visited
)
flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params) flat_dependant.header_params.extend(flat_sub.header_params)
@ -209,10 +215,14 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
for param in fields: for param in fields:
query_extra_info[param] = dict(fields[param].__repr_args__()) query_extra_info[param] = dict(fields[param].__repr_args__())
if "alias" in query_extra_info[param]: if "alias" in query_extra_info[param]:
query_extra_info[query_extra_info[param]["alias"]] = dict(fields[param].__repr_args__()) query_extra_info[query_extra_info[param]["alias"]] = dict(
fields[param].__repr_args__()
)
alias_dict[query_extra_info[param]["alias"]] = param alias_dict[query_extra_info[param]["alias"]] = param
query_extra_info[param]["default"] = ( query_extra_info[param]["default"] = (
Required if getattr(fields[param], "required", False) else fields[param].default Required
if getattr(fields[param], "required", False)
else fields[param].default
) )
typed_params = [] typed_params = []
@ -306,7 +316,9 @@ def get_dependant(
type_annotation=type_annotation, type_annotation=type_annotation,
dependant=dependant, dependant=dependant,
): ):
assert param_field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}" assert (
param_field is None
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
continue continue
assert param_field is not None assert param_field is not None
if is_body_param(param_field=param_field, is_path_param=is_path_param): if is_body_param(param_field=param_field, is_path_param=is_path_param):
@ -316,7 +328,9 @@ def get_dependant(
return dependant return dependant
def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]: def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request): if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name dependant.request_param_name = param_name
return True return True
@ -348,17 +362,26 @@ def analyze_param(
field_info = None field_info = None
depends = None depends = None
type_annotation: Any = Any type_annotation: Any = Any
if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: if (
annotation is not inspect.Signature.empty
and get_origin(annotation) is Annotated
):
annotated_args = get_args(annotation) annotated_args = get_args(annotation)
type_annotation = annotated_args[0] type_annotation = annotated_args[0]
fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))] fastapi_annotations = [
arg
for arg in annotated_args[1:]
if isinstance(arg, (FieldInfo, params.Depends))
]
assert ( assert (
len(fastapi_annotations) <= 1 len(fastapi_annotations) <= 1
), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}"
fastapi_annotation = next(iter(fastapi_annotations), None) fastapi_annotation = next(iter(fastapi_annotations), None)
if isinstance(fastapi_annotation, FieldInfo): if isinstance(fastapi_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` below. # Copy `field_info` because we mutate `field_info.default` below.
field_info = copy_field_info(field_info=fastapi_annotation, annotation=annotation) field_info = copy_field_info(
field_info=fastapi_annotation, annotation=annotation
)
assert field_info.default is Undefined or field_info.default is Required, ( assert field_info.default is Undefined or field_info.default is Required, (
f"`{field_info.__class__.__name__}` default value cannot be set in" f"`{field_info.__class__.__name__}` default value cannot be set in"
f" `Annotated` for {param_name!r}. Set the default value with `=` instead." f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
@ -375,7 +398,8 @@ def analyze_param(
if isinstance(value, params.Depends): if isinstance(value, params.Depends):
assert depends is None, ( assert depends is None, (
"Cannot specify `Depends` in `Annotated` and default value" f" together for {param_name!r}" "Cannot specify `Depends` in `Annotated` and default value"
f" together for {param_name!r}"
) )
assert field_info is None, ( assert field_info is None, (
"Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
@ -384,7 +408,8 @@ def analyze_param(
depends = value depends = value
elif isinstance(value, FieldInfo): elif isinstance(value, FieldInfo):
assert field_info is None, ( assert field_info is None, (
"Cannot specify FastAPI annotations in `Annotated` and default value" f" together for {param_name!r}" "Cannot specify FastAPI annotations in `Annotated` and default value"
f" together for {param_name!r}"
) )
field_info = value field_info = value
if PYDANTIC_V2: if PYDANTIC_V2:
@ -405,7 +430,9 @@ def analyze_param(
), ),
): ):
assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}" assert (
field_info is None
), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
elif field_info is None and depends is None: elif field_info is None and depends is None:
default_value = value if value is not inspect.Signature.empty else Required default_value = value if value is not inspect.Signature.empty else Required
if is_path_param: if is_path_param:
@ -413,9 +440,9 @@ def analyze_param(
# parameter might sometimes be a path parameter and sometimes not. See # parameter might sometimes be a path parameter and sometimes not. See
# `tests/test_infer_param_optionality.py` for an example. # `tests/test_infer_param_optionality.py` for an example.
field_info = params.Path(annotation=type_annotation) field_info = params.Path(annotation=type_annotation)
elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation( elif is_uploadfile_or_nonable_uploadfile_annotation(
type_annotation type_annotation
): ) or is_uploadfile_sequence_annotation(type_annotation):
field_info = params.File(annotation=type_annotation, default=default_value) field_info = params.File(annotation=type_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation): elif not field_annotation_is_scalar(annotation=type_annotation):
field_info = params.Body(annotation=type_annotation, default=default_value) field_info = params.Body(annotation=type_annotation, default=default_value)
@ -426,9 +453,13 @@ def analyze_param(
if field_info is not None: if field_info is not None:
if is_path_param: if is_path_param:
assert isinstance(field_info, params.Path), ( assert isinstance(field_info, params.Path), (
f"Cannot use `{field_info.__class__.__name__}` for path param" f" {param_name!r}" f"Cannot use `{field_info.__class__.__name__}` for path param"
f" {param_name!r}"
) )
elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None: elif (
isinstance(field_info, params.Param)
and getattr(field_info, "in_", None) is None
):
field_info.in_ = params.ParamTypes.query field_info.in_ = params.ParamTypes.query
use_annotation = get_annotation_from_field_info( use_annotation = get_annotation_from_field_info(
type_annotation, type_annotation,
@ -454,11 +485,15 @@ def analyze_param(
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
if is_path_param: if is_path_param:
assert is_scalar_field(field=param_field), "Path params must be of one of the supported types" assert is_scalar_field(
field=param_field
), "Path params must be of one of the supported types"
return False return False
elif is_scalar_field(field=param_field): elif is_scalar_field(field=param_field):
return False return False
elif isinstance(param_field.field_info, (params.Query, params.Header)) and is_scalar_sequence_field(param_field): elif isinstance(
param_field.field_info, (params.Query, params.Header)
) and is_scalar_sequence_field(param_field):
return False return False
else: else:
assert isinstance( assert isinstance(
@ -505,7 +540,9 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
return inspect.isgeneratorfunction(dunder_call) return inspect.isgeneratorfunction(dunder_call)
async def solve_generator(*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]) -> Any: async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
if is_gen_callable(call): if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
elif is_async_gen_callable(call): elif is_async_gen_callable(call):
@ -539,12 +576,19 @@ async def solve_dependencies(
sub_dependant: Dependant sub_dependant: Dependant
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key) sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call call = sub_dependant.call
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependant
if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides: if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call original_call = sub_dependant.call
call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call) call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant( use_sub_dependant = get_dependant(
path=use_path, path=use_path,
@ -578,7 +622,9 @@ async def solve_dependencies(
elif is_gen_callable(call) or is_async_gen_callable(call): elif is_gen_callable(call) or is_async_gen_callable(call):
stack = request.scope.get("fastapi_astack") stack = request.scope.get("fastapi_astack")
assert isinstance(stack, AsyncExitStack) assert isinstance(stack, AsyncExitStack)
solved = await solve_generator(call=call, stack=stack, sub_values=sub_values) solved = await solve_generator(
call=call, stack=stack, sub_values=sub_values
)
elif is_coroutine_callable(call): elif is_coroutine_callable(call):
solved = await call(**sub_values) solved = await call(**sub_values)
else: else:
@ -587,10 +633,18 @@ async def solve_dependencies(
values[sub_dependant.name] = solved values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache: if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params) path_values, path_errors = request_params_to_args(
query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params) dependant.path_params, request.path_params
header_values, header_errors = request_params_to_args(dependant.header_params, request.headers) )
cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies) query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values) values.update(path_values)
values.update(query_values) values.update(query_values)
values.update(header_values) values.update(header_values)
@ -618,7 +672,9 @@ async def solve_dependencies(
if dependant.response_param_name: if dependant.response_param_name:
values[dependant.response_param_name] = response values[dependant.response_param_name] = response
if dependant.security_scopes_param_name: if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.security_scopes) values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return values, errors, background_tasks, response, dependency_cache return values, errors, background_tasks, response, dependency_cache
@ -629,12 +685,16 @@ def request_params_to_args(
values = {} values = {}
errors = [] errors = []
for field in required_params: for field in required_params:
if is_scalar_sequence_field(field) and isinstance(received_params, (QueryParams, Headers)): if is_scalar_sequence_field(field) and isinstance(
received_params, (QueryParams, Headers)
):
value = received_params.getlist(field.alias) or field.default value = received_params.getlist(field.alias) or field.default
else: else:
value = received_params.get(field.alias) value = received_params.get(field.alias)
field_info = field.field_info field_info = field.field_info
assert isinstance(field_info, params.Param), "Params must be subclasses of Param" assert isinstance(
field_info, params.Param
), "Params must be subclasses of Param"
loc = (field_info.in_.value, field.alias) loc = (field_info.in_.value, field.alias)
if value is None: if value is None:
if field.required: if field.required:
@ -687,21 +747,35 @@ async def request_body_to_args(
if ( if (
value is None value is None
or (isinstance(field_info, params.Form) and value == "") or (isinstance(field_info, params.Form) and value == "")
or (isinstance(field_info, params.Form) and is_sequence_field(field) and len(value) == 0) or (
isinstance(field_info, params.Form)
and is_sequence_field(field)
and len(value) == 0
)
): ):
if field.required: if field.required:
errors.append(get_missing_field_error(loc)) errors.append(get_missing_field_error(loc))
else: else:
values[field.name] = deepcopy(field.default) values[field.name] = deepcopy(field.default)
continue continue
if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile): if (
isinstance(field_info, params.File)
and is_bytes_field(field)
and isinstance(value, UploadFile)
):
value = await value.read() value = await value.read()
elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value): elif (
is_bytes_sequence_field(field)
and isinstance(field_info, params.File)
and value_is_sequence(value)
):
# For types # For types
assert isinstance(value, sequence_types) # type: ignore[arg-type] assert isinstance(value, sequence_types) # type: ignore[arg-type]
results: List[Union[bytes, str]] = [] results: List[Union[bytes, str]] = []
async def process_fn(fn: Callable[[], Coroutine[Any, Any, Any]]) -> None: async def process_fn(
fn: Callable[[], Coroutine[Any, Any, Any]]
) -> None:
result = await fn() result = await fn()
results.append(result) # noqa: B023 results.append(result) # noqa: B023
@ -738,7 +812,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
for param in flat_dependant.body_params: for param in flat_dependant.body_params:
setattr(param.field_info, "embed", True) # noqa: B010 setattr(param.field_info, "embed", True) # noqa: B010
model_name = "Body_" + name model_name = "Body_" + name
BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name) BodyModel = create_body_model(
fields=flat_dependant.body_params, model_name=model_name
)
required = any(True for f in flat_dependant.body_params if f.required) required = any(True for f in flat_dependant.body_params if f.required)
BodyFieldInfo_kwargs: Dict[str, Any] = { BodyFieldInfo_kwargs: Dict[str, Any] = {
"annotation": BodyModel, "annotation": BodyModel,
@ -754,7 +830,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
BodyFieldInfo = params.Body BodyFieldInfo = params.Body
body_param_media_types = [ body_param_media_types = [
f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body) f.field_info.media_type
for f in flat_dependant.body_params
if isinstance(f.field_info, params.Body)
] ]
if len(set(body_param_media_types)) == 1: if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]

Loading…
Cancel
Save