diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..614910308 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -32,6 +32,7 @@ class Dependant: use_cache: bool = True path: Optional[str] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) + parallelizable: bool = True def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 818fd93ac..fe58dd641 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -167,6 +167,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + parallelizable=depends.parallelizable, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -194,6 +195,7 @@ def get_flat_dependant( body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), use_cache=dependant.use_cache, + parallelizable=dependant.parallelizable, path=dependant.path, ) for sub_dependant in dependant.dependencies: @@ -271,6 +273,7 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + parallelizable: bool = True, ) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -281,6 +284,7 @@ def get_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + parallelizable=parallelizable, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -578,8 +582,9 @@ class DependencySolveException(Exception): def is_context_sensitive(dependant: Dependant) -> bool: - if dependant.call is not None and ( - is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call) + if dependant.parallelizable is False or ( + dependant.call is not None + and (is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call)) ): return True return any(is_context_sensitive(sub) for sub in dependant.dependencies) @@ -721,6 +726,21 @@ async def solve_dependencies( ensure_cache(resolved) return sub_dependant.name, resolved, None, None + def unpack_results( + results: list[ + Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]] + ], + ) -> None: + for name, value, sub_errors, sub_exception in results: + # Ensures order of exception based on dependency order. + if sub_exception: + raise sub_exception + if sub_errors: + errors.extend(sub_errors) + continue + if name is not None: + values[name] = value + sequential_deps = [] parallel_deps = [] for sub in dependant.dependencies: @@ -729,27 +749,23 @@ async def solve_dependencies( else: parallel_deps.append(sub) - results = [] - + sequential_results: list[ + Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]] + ] = [] for sub in sequential_deps: - result = await resolve_sub_dependant(sub) - results.append(result) + s_result = await resolve_sub_dependant(sub) + sequential_results.append(s_result) + unpack_results(sequential_results) + parallel_results: list[ + Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]] + ] = [] if parallel_deps: - parallel_results = await asyncio.gather( + p_result = await asyncio.gather( *[resolve_sub_dependant(sub) for sub in parallel_deps] ) - results.extend(parallel_results) - - for name, result, sub_errors, sub_exception in results: - # Ensures order of exception based on dependency order. - if sub_exception: - raise sub_exception - if sub_errors: - errors.extend(sub_errors) - continue - if name is not None: - values[name] = result + parallel_results.extend(p_result) + unpack_results(parallel_results) path_values, path_errors = request_params_to_args( dependant.path_params, request.path_params diff --git a/fastapi/params.py b/fastapi/params.py index 8f5601dd3..17c3642b3 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -763,10 +763,15 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, + parallelizable: bool = True, ): self.dependency = dependency self.use_cache = use_cache + self.parallelizable = parallelizable def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) @@ -781,6 +786,9 @@ class Security(Depends): *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, + parallelizable: bool = True, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__( + dependency=dependency, use_cache=use_cache, parallelizable=parallelizable + ) self.scopes = scopes or []