From 8dbdd3d89cedaa17d031070425abc69d6253013f Mon Sep 17 00:00:00 2001 From: Zoltan Fodor Date: Tue, 3 Jun 2025 17:34:31 +0300 Subject: [PATCH 1/2] Parallelize dependency resolution. (#639) --- fastapi/dependencies/utils.py | 182 ++++++++++++++++++++++++++------- tests/test_dependency_cache.py | 2 +- 2 files changed, 147 insertions(+), 37 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 84dfa4d03..818fd93ac 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,4 +1,6 @@ +import asyncio import inspect +from asyncio import Future from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -569,6 +571,33 @@ class SolvedDependency: dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] +class DependencySolveException(Exception): + def __init__(self, errors: List[Any]): + super().__init__(str(errors)) + self.errors = errors + + +def is_context_sensitive(dependant: Dependant) -> bool: + if dependant.call is not None and ( + is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call) + ): + return True + return any(is_context_sensitive(sub) for sub in dependant.dependencies) + + +def is_context_with_background_task(dependant: Dependant) -> bool: + if dependant.background_tasks_param_name: + return True + return any(is_context_with_background_task(sub) for sub in dependant.dependencies) + + +def silence_future_exception(fut: "Future[Any]") -> None: + try: + fut.exception() + except BaseException: + pass # Silences the warning + + async def solve_dependencies( *, request: Union[Request, WebSocket], @@ -583,19 +612,25 @@ async def solve_dependencies( ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] + if response is None: response = Response() del response.headers["content-length"] response.status_code = None # type: ignore + dependency_cache = dependency_cache or {} - sub_dependant: Dependant - for sub_dependant in dependant.dependencies: + + if background_tasks is None and is_context_with_background_task(dependant): + background_tasks = BackgroundTasks() + + async def resolve_sub_dependant( + sub_dependant: Dependant, + ) -> Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key - ) + cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key) call = sub_dependant.call use_sub_dependant = sub_dependant + if ( dependency_overrides_provider and dependency_overrides_provider.dependency_overrides @@ -612,36 +647,110 @@ async def solve_dependencies( security_scopes=sub_dependant.security_scopes, ) - solved_result = await solve_dependencies( - request=request, - dependant=use_sub_dependant, - body=body, - background_tasks=background_tasks, - response=response, - dependency_overrides_provider=dependency_overrides_provider, - dependency_cache=dependency_cache, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, - ) - background_tasks = solved_result.background_tasks - dependency_cache.update(solved_result.dependency_cache) - if solved_result.errors: - errors.extend(solved_result.errors) - continue - if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: - solved = dependency_cache[sub_dependant.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): - solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values + def resolve_cached() -> Optional["Future[Any]"]: + if not sub_dependant.use_cache: + return None + + if cache_key not in dependency_cache: + future = asyncio.get_event_loop().create_future() + # Ensures ignored exceptions are not logged with warning, as we only raise the first one. + future.add_done_callback(silence_future_exception) + dependency_cache[cache_key] = future + return None + + cached = dependency_cache[cache_key] + + if isinstance(cached, Future): + return cached + + future = Future() + future.set_result(cached) + return future + + async def resolve_value() -> Any: + solved_result = await solve_dependencies( + request=request, + dependant=use_sub_dependant, + body=body, + background_tasks=background_tasks, + response=response, + dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, + async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, ) - elif is_coroutine_callable(call): - solved = await call(**solved_result.values) + + if solved_result.errors: + raise DependencySolveException(solved_result.errors) + + if is_gen_callable(call) or is_async_gen_callable(call): + return await solve_generator( + call=call, + stack=async_exit_stack, + sub_values=solved_result.values, + ) + elif is_coroutine_callable(call): + return await call(**solved_result.values) + else: + return await run_in_threadpool(call, **solved_result.values) + + task = resolve_cached() or resolve_value() + + def ensure_cache(value: Optional[Any]) -> None: + if ( + cache_key not in dependency_cache + or not isinstance(dependency_cache[cache_key], Future) + or dependency_cache[cache_key].done() + ): + return + + if isinstance(value, BaseException): + dependency_cache[cache_key].set_exception(value) + else: + dependency_cache[cache_key].set_result(value) + + try: + resolved = await task + except DependencySolveException as exc: + ensure_cache(None) + return sub_dependant.name, None, exc.errors, None + except BaseException as resolution_exc: + ensure_cache(resolution_exc) + return sub_dependant.name, None, None, resolution_exc + + ensure_cache(resolved) + return sub_dependant.name, resolved, None, None + + sequential_deps = [] + parallel_deps = [] + for sub in dependant.dependencies: + if is_context_sensitive(sub): + sequential_deps.append(sub) else: - solved = await run_in_threadpool(call, **solved_result.values) - if sub_dependant.name is not None: - values[sub_dependant.name] = solved - if sub_dependant.cache_key not in dependency_cache: - dependency_cache[sub_dependant.cache_key] = solved + parallel_deps.append(sub) + + results = [] + + for sub in sequential_deps: + result = await resolve_sub_dependant(sub) + results.append(result) + + if parallel_deps: + parallel_results = await asyncio.gather( + *[resolve_sub_dependant(sub) for sub in parallel_deps] + ) + results.extend(parallel_results) + + for name, result, sub_errors, sub_exception in results: + # Ensures order of exception based on dependency order. + if sub_exception: + raise sub_exception + if sub_errors: + errors.extend(sub_errors) + continue + if name is not None: + values[name] = result + path_values, path_errors = request_params_to_args( dependant.path_params, request.path_params ) @@ -659,17 +768,19 @@ async def solve_dependencies( values.update(header_values) values.update(cookie_values) errors += path_errors + query_errors + header_errors + cookie_errors + if dependant.body_params: ( body_values, body_errors, - ) = await request_body_to_args( # body_params checked above + ) = await request_body_to_args( body_fields=dependant.body_params, received_body=body, embed_body_fields=embed_body_fields, ) values.update(body_values) errors.extend(body_errors) + if dependant.http_connection_param_name: values[dependant.http_connection_param_name] = request if dependant.request_param_name and isinstance(request, Request): @@ -677,8 +788,6 @@ async def solve_dependencies( elif dependant.websocket_param_name and isinstance(request, WebSocket): values[dependant.websocket_param_name] = request if dependant.background_tasks_param_name: - if background_tasks is None: - background_tasks = BackgroundTasks() values[dependant.background_tasks_param_name] = background_tasks if dependant.response_param_name: values[dependant.response_param_name] = response @@ -686,6 +795,7 @@ async def solve_dependencies( values[dependant.security_scopes_param_name] = SecurityScopes( scopes=dependant.security_scopes ) + return SolvedDependency( values=values, errors=errors, diff --git a/tests/test_dependency_cache.py b/tests/test_dependency_cache.py index 08fb9b74f..9d3a0ae9d 100644 --- a/tests/test_dependency_cache.py +++ b/tests/test_dependency_cache.py @@ -29,7 +29,7 @@ async def get_sub_counter( @app.get("/sub-counter-no-cache/") async def get_sub_counter_no_cache( - subcount: int = Depends(super_dep), + subcount: int = Depends(dep_counter), count: int = Depends(dep_counter, use_cache=False), ): return {"counter": count, "subcounter": subcount} From d518b23bef56371215ab6bd6ad15c2084233d548 Mon Sep 17 00:00:00 2001 From: Zoltan Fodor Date: Mon, 30 Jun 2025 11:44:34 +0300 Subject: [PATCH 2/2] Allow dependency to opt out from parallelization. (#639) --- fastapi/dependencies/models.py | 1 + fastapi/dependencies/utils.py | 52 ++++++++++++++++++++++------------ fastapi/params.py | 12 ++++++-- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..614910308 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -32,6 +32,7 @@ class Dependant: use_cache: bool = True path: Optional[str] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) + parallelizable: bool = True def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 818fd93ac..fe58dd641 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -167,6 +167,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + parallelizable=depends.parallelizable, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -194,6 +195,7 @@ def get_flat_dependant( body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), use_cache=dependant.use_cache, + parallelizable=dependant.parallelizable, path=dependant.path, ) for sub_dependant in dependant.dependencies: @@ -271,6 +273,7 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + parallelizable: bool = True, ) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -281,6 +284,7 @@ def get_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + parallelizable=parallelizable, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -578,8 +582,9 @@ class DependencySolveException(Exception): def is_context_sensitive(dependant: Dependant) -> bool: - if dependant.call is not None and ( - is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call) + if dependant.parallelizable is False or ( + dependant.call is not None + and (is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call)) ): return True return any(is_context_sensitive(sub) for sub in dependant.dependencies) @@ -721,6 +726,21 @@ async def solve_dependencies( ensure_cache(resolved) return sub_dependant.name, resolved, None, None + def unpack_results( + results: list[ + Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]] + ], + ) -> None: + for name, value, sub_errors, sub_exception in results: + # Ensures order of exception based on dependency order. + if sub_exception: + raise sub_exception + if sub_errors: + errors.extend(sub_errors) + continue + if name is not None: + values[name] = value + sequential_deps = [] parallel_deps = [] for sub in dependant.dependencies: @@ -729,27 +749,23 @@ async def solve_dependencies( else: parallel_deps.append(sub) - results = [] - + sequential_results: list[ + Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]] + ] = [] for sub in sequential_deps: - result = await resolve_sub_dependant(sub) - results.append(result) + s_result = await resolve_sub_dependant(sub) + sequential_results.append(s_result) + unpack_results(sequential_results) + parallel_results: list[ + Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]] + ] = [] if parallel_deps: - parallel_results = await asyncio.gather( + p_result = await asyncio.gather( *[resolve_sub_dependant(sub) for sub in parallel_deps] ) - results.extend(parallel_results) - - for name, result, sub_errors, sub_exception in results: - # Ensures order of exception based on dependency order. - if sub_exception: - raise sub_exception - if sub_errors: - errors.extend(sub_errors) - continue - if name is not None: - values[name] = result + parallel_results.extend(p_result) + unpack_results(parallel_results) path_values, path_errors = request_params_to_args( dependant.path_params, request.path_params diff --git a/fastapi/params.py b/fastapi/params.py index 8f5601dd3..17c3642b3 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -763,10 +763,15 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, + parallelizable: bool = True, ): self.dependency = dependency self.use_cache = use_cache + self.parallelizable = parallelizable def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) @@ -781,6 +786,9 @@ class Security(Depends): *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, + parallelizable: bool = True, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__( + dependency=dependency, use_cache=use_cache, parallelizable=parallelizable + ) self.scopes = scopes or []