From 21532f792836f65f6ffe0d752e7dfc5b62aaa78d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:56:26 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?= =?UTF-8?q?=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/dependencies/utils.py | 152 +++++++++++++++++++++++++--------- 1 file changed, 115 insertions(+), 37 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 95fd5b033..caa18e118 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -119,7 +119,9 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: - assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency" + assert callable( + depends.dependency + ), "A parameter-less dependency must have a callable dependency" return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) @@ -140,7 +142,9 @@ def get_sub_dependant( use_scopes: List[str] = [] if isinstance(dependency, (OAuth2, OpenIdConnect)): use_scopes = security_scopes - security_requirement = SecurityRequirement(security_scheme=dependency, scopes=use_scopes) + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=use_scopes + ) sub_dependant = get_dependant( path=path, call=dependency, @@ -179,7 +183,9 @@ def get_flat_dependant( for sub_dependant in dependant.dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue - flat_sub = get_flat_dependant(sub_dependant, skip_repeats=skip_repeats, visited=visited) + flat_sub = get_flat_dependant( + sub_dependant, skip_repeats=skip_repeats, visited=visited + ) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) @@ -209,10 +215,14 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: for param in fields: query_extra_info[param] = dict(fields[param].__repr_args__()) if "alias" in query_extra_info[param]: - query_extra_info[query_extra_info[param]["alias"]] = dict(fields[param].__repr_args__()) + query_extra_info[query_extra_info[param]["alias"]] = dict( + fields[param].__repr_args__() + ) alias_dict[query_extra_info[param]["alias"]] = param query_extra_info[param]["default"] = ( - Required if getattr(fields[param], "required", False) else fields[param].default + Required + if getattr(fields[param], "required", False) + else fields[param].default ) typed_params = [] @@ -306,7 +316,9 @@ def get_dependant( type_annotation=type_annotation, dependant=dependant, ): - assert param_field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}" + assert ( + param_field is None + ), f"Cannot specify multiple FastAPI annotations for {param_name!r}" continue assert param_field is not None if is_body_param(param_field=param_field, is_path_param=is_path_param): @@ -316,7 +328,9 @@ def get_dependant( return dependant -def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]: +def add_non_field_param_to_dependency( + *, param_name: str, type_annotation: Any, dependant: Dependant +) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name return True @@ -348,17 +362,26 @@ def analyze_param( field_info = None depends = None type_annotation: Any = Any - if annotation is not inspect.Signature.empty and get_origin(annotation) is Annotated: + if ( + annotation is not inspect.Signature.empty + and get_origin(annotation) is Annotated + ): annotated_args = get_args(annotation) type_annotation = annotated_args[0] - fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))] + fastapi_annotations = [ + arg + for arg in annotated_args[1:] + if isinstance(arg, (FieldInfo, params.Depends)) + ] assert ( len(fastapi_annotations) <= 1 ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" fastapi_annotation = next(iter(fastapi_annotations), None) if isinstance(fastapi_annotation, FieldInfo): # Copy `field_info` because we mutate `field_info.default` below. - field_info = copy_field_info(field_info=fastapi_annotation, annotation=annotation) + field_info = copy_field_info( + field_info=fastapi_annotation, annotation=annotation + ) assert field_info.default is Undefined or field_info.default is Required, ( f"`{field_info.__class__.__name__}` default value cannot be set in" f" `Annotated` for {param_name!r}. Set the default value with `=` instead." @@ -375,7 +398,8 @@ def analyze_param( if isinstance(value, params.Depends): assert depends is None, ( - "Cannot specify `Depends` in `Annotated` and default value" f" together for {param_name!r}" + "Cannot specify `Depends` in `Annotated` and default value" + f" together for {param_name!r}" ) assert field_info is None, ( "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" @@ -384,7 +408,8 @@ def analyze_param( depends = value elif isinstance(value, FieldInfo): assert field_info is None, ( - "Cannot specify FastAPI annotations in `Annotated` and default value" f" together for {param_name!r}" + "Cannot specify FastAPI annotations in `Annotated` and default value" + f" together for {param_name!r}" ) field_info = value if PYDANTIC_V2: @@ -405,7 +430,9 @@ def analyze_param( ), ): assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" - assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}" + assert ( + field_info is None + ), f"Cannot specify FastAPI annotation for type {type_annotation!r}" elif field_info is None and depends is None: default_value = value if value is not inspect.Signature.empty else Required if is_path_param: @@ -413,9 +440,9 @@ def analyze_param( # parameter might sometimes be a path parameter and sometimes not. See # `tests/test_infer_param_optionality.py` for an example. field_info = params.Path(annotation=type_annotation) - elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation( + elif is_uploadfile_or_nonable_uploadfile_annotation( type_annotation - ): + ) or is_uploadfile_sequence_annotation(type_annotation): field_info = params.File(annotation=type_annotation, default=default_value) elif not field_annotation_is_scalar(annotation=type_annotation): field_info = params.Body(annotation=type_annotation, default=default_value) @@ -426,9 +453,13 @@ def analyze_param( if field_info is not None: if is_path_param: assert isinstance(field_info, params.Path), ( - f"Cannot use `{field_info.__class__.__name__}` for path param" f" {param_name!r}" + f"Cannot use `{field_info.__class__.__name__}` for path param" + f" {param_name!r}" ) - elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None: + elif ( + isinstance(field_info, params.Param) + and getattr(field_info, "in_", None) is None + ): field_info.in_ = params.ParamTypes.query use_annotation = get_annotation_from_field_info( type_annotation, @@ -454,11 +485,15 @@ def analyze_param( def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: if is_path_param: - assert is_scalar_field(field=param_field), "Path params must be of one of the supported types" + assert is_scalar_field( + field=param_field + ), "Path params must be of one of the supported types" return False elif is_scalar_field(field=param_field): return False - elif isinstance(param_field.field_info, (params.Query, params.Header)) and is_scalar_sequence_field(param_field): + elif isinstance( + param_field.field_info, (params.Query, params.Header) + ) and is_scalar_sequence_field(param_field): return False else: assert isinstance( @@ -505,7 +540,9 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: return inspect.isgeneratorfunction(dunder_call) -async def solve_generator(*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]) -> Any: +async def solve_generator( + *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] +) -> Any: if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) elif is_async_gen_callable(call): @@ -539,12 +576,19 @@ async def solve_dependencies( sub_dependant: Dependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key) + sub_dependant.cache_key = cast( + Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key + ) call = sub_dependant.call use_sub_dependant = sub_dependant - if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides: + if ( + dependency_overrides_provider + and dependency_overrides_provider.dependency_overrides + ): original_call = sub_dependant.call - call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call) + call = getattr( + dependency_overrides_provider, "dependency_overrides", {} + ).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore use_sub_dependant = get_dependant( path=use_path, @@ -578,7 +622,9 @@ async def solve_dependencies( elif is_gen_callable(call) or is_async_gen_callable(call): stack = request.scope.get("fastapi_astack") assert isinstance(stack, AsyncExitStack) - solved = await solve_generator(call=call, stack=stack, sub_values=sub_values) + solved = await solve_generator( + call=call, stack=stack, sub_values=sub_values + ) elif is_coroutine_callable(call): solved = await call(**sub_values) else: @@ -587,10 +633,18 @@ async def solve_dependencies( values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: dependency_cache[sub_dependant.cache_key] = solved - path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params) - query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params) - header_values, header_errors = request_params_to_args(dependant.header_params, request.headers) - cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies) + path_values, path_errors = request_params_to_args( + dependant.path_params, request.path_params + ) + query_values, query_errors = request_params_to_args( + dependant.query_params, request.query_params + ) + header_values, header_errors = request_params_to_args( + dependant.header_params, request.headers + ) + cookie_values, cookie_errors = request_params_to_args( + dependant.cookie_params, request.cookies + ) values.update(path_values) values.update(query_values) values.update(header_values) @@ -618,7 +672,9 @@ async def solve_dependencies( if dependant.response_param_name: values[dependant.response_param_name] = response if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.security_scopes) + values[dependant.security_scopes_param_name] = SecurityScopes( + scopes=dependant.security_scopes + ) return values, errors, background_tasks, response, dependency_cache @@ -629,12 +685,16 @@ def request_params_to_args( values = {} errors = [] for field in required_params: - if is_scalar_sequence_field(field) and isinstance(received_params, (QueryParams, Headers)): + if is_scalar_sequence_field(field) and isinstance( + received_params, (QueryParams, Headers) + ): value = received_params.getlist(field.alias) or field.default else: value = received_params.get(field.alias) field_info = field.field_info - assert isinstance(field_info, params.Param), "Params must be subclasses of Param" + assert isinstance( + field_info, params.Param + ), "Params must be subclasses of Param" loc = (field_info.in_.value, field.alias) if value is None: if field.required: @@ -687,21 +747,35 @@ async def request_body_to_args( if ( value is None or (isinstance(field_info, params.Form) and value == "") - or (isinstance(field_info, params.Form) and is_sequence_field(field) and len(value) == 0) + or ( + isinstance(field_info, params.Form) + and is_sequence_field(field) + and len(value) == 0 + ) ): if field.required: errors.append(get_missing_field_error(loc)) else: values[field.name] = deepcopy(field.default) continue - if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile): + if ( + isinstance(field_info, params.File) + and is_bytes_field(field) + and isinstance(value, UploadFile) + ): value = await value.read() - elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value): + elif ( + is_bytes_sequence_field(field) + and isinstance(field_info, params.File) + and value_is_sequence(value) + ): # For types assert isinstance(value, sequence_types) # type: ignore[arg-type] results: List[Union[bytes, str]] = [] - async def process_fn(fn: Callable[[], Coroutine[Any, Any, Any]]) -> None: + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]] + ) -> None: result = await fn() results.append(result) # noqa: B023 @@ -738,7 +812,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: for param in flat_dependant.body_params: setattr(param.field_info, "embed", True) # noqa: B010 model_name = "Body_" + name - BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name) + BodyModel = create_body_model( + fields=flat_dependant.body_params, model_name=model_name + ) required = any(True for f in flat_dependant.body_params if f.required) BodyFieldInfo_kwargs: Dict[str, Any] = { "annotation": BodyModel, @@ -754,7 +830,9 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: BodyFieldInfo = params.Body body_param_media_types = [ - f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body) + f.field_info.media_type + for f in flat_dependant.body_params + if isinstance(f.field_info, params.Body) ] if len(set(body_param_media_types)) == 1: BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]