Browse Source

Parallelize dependency resolution. (#639)

pull/13756/head
Zoltan Fodor 2 months ago
parent
commit
8dbdd3d89c
  1. 156
      fastapi/dependencies/utils.py
  2. 2
      tests/test_dependency_cache.py

156
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import asyncio
import inspect import inspect
from asyncio import Future
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -569,6 +571,33 @@ class SolvedDependency:
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
class DependencySolveException(Exception):
def __init__(self, errors: List[Any]):
super().__init__(str(errors))
self.errors = errors
def is_context_sensitive(dependant: Dependant) -> bool:
if dependant.call is not None and (
is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call)
):
return True
return any(is_context_sensitive(sub) for sub in dependant.dependencies)
def is_context_with_background_task(dependant: Dependant) -> bool:
if dependant.background_tasks_param_name:
return True
return any(is_context_with_background_task(sub) for sub in dependant.dependencies)
def silence_future_exception(fut: "Future[Any]") -> None:
try:
fut.exception()
except BaseException:
pass # Silences the warning
async def solve_dependencies( async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
@ -583,19 +612,25 @@ async def solve_dependencies(
) -> SolvedDependency: ) -> SolvedDependency:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Any] = [] errors: List[Any] = []
if response is None: if response is None:
response = Response() response = Response()
del response.headers["content-length"] del response.headers["content-length"]
response.status_code = None # type: ignore response.status_code = None # type: ignore
dependency_cache = dependency_cache or {} dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies: if background_tasks is None and is_context_with_background_task(dependant):
background_tasks = BackgroundTasks()
async def resolve_sub_dependant(
sub_dependant: Dependant,
) -> Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast( cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key)
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call call = sub_dependant.call
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependant
if ( if (
dependency_overrides_provider dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides and dependency_overrides_provider.dependency_overrides
@ -612,6 +647,27 @@ async def solve_dependencies(
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
) )
def resolve_cached() -> Optional["Future[Any]"]:
if not sub_dependant.use_cache:
return None
if cache_key not in dependency_cache:
future = asyncio.get_event_loop().create_future()
# Ensures ignored exceptions are not logged with warning, as we only raise the first one.
future.add_done_callback(silence_future_exception)
dependency_cache[cache_key] = future
return None
cached = dependency_cache[cache_key]
if isinstance(cached, Future):
return cached
future = Future()
future.set_result(cached)
return future
async def resolve_value() -> Any:
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
request=request, request=request,
dependant=use_sub_dependant, dependant=use_sub_dependant,
@ -623,25 +679,78 @@ async def solve_dependencies(
async_exit_stack=async_exit_stack, async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
background_tasks = solved_result.background_tasks
dependency_cache.update(solved_result.dependency_cache)
if solved_result.errors: if solved_result.errors:
errors.extend(solved_result.errors) raise DependencySolveException(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: if is_gen_callable(call) or is_async_gen_callable(call):
solved = dependency_cache[sub_dependant.cache_key] return await solve_generator(
elif is_gen_callable(call) or is_async_gen_callable(call): call=call,
solved = await solve_generator( stack=async_exit_stack,
call=call, stack=async_exit_stack, sub_values=solved_result.values sub_values=solved_result.values,
) )
elif is_coroutine_callable(call): elif is_coroutine_callable(call):
solved = await call(**solved_result.values) return await call(**solved_result.values)
else:
return await run_in_threadpool(call, **solved_result.values)
task = resolve_cached() or resolve_value()
def ensure_cache(value: Optional[Any]) -> None:
if (
cache_key not in dependency_cache
or not isinstance(dependency_cache[cache_key], Future)
or dependency_cache[cache_key].done()
):
return
if isinstance(value, BaseException):
dependency_cache[cache_key].set_exception(value)
else: else:
solved = await run_in_threadpool(call, **solved_result.values) dependency_cache[cache_key].set_result(value)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved try:
if sub_dependant.cache_key not in dependency_cache: resolved = await task
dependency_cache[sub_dependant.cache_key] = solved except DependencySolveException as exc:
ensure_cache(None)
return sub_dependant.name, None, exc.errors, None
except BaseException as resolution_exc:
ensure_cache(resolution_exc)
return sub_dependant.name, None, None, resolution_exc
ensure_cache(resolved)
return sub_dependant.name, resolved, None, None
sequential_deps = []
parallel_deps = []
for sub in dependant.dependencies:
if is_context_sensitive(sub):
sequential_deps.append(sub)
else:
parallel_deps.append(sub)
results = []
for sub in sequential_deps:
result = await resolve_sub_dependant(sub)
results.append(result)
if parallel_deps:
parallel_results = await asyncio.gather(
*[resolve_sub_dependant(sub) for sub in parallel_deps]
)
results.extend(parallel_results)
for name, result, sub_errors, sub_exception in results:
# Ensures order of exception based on dependency order.
if sub_exception:
raise sub_exception
if sub_errors:
errors.extend(sub_errors)
continue
if name is not None:
values[name] = result
path_values, path_errors = request_params_to_args( path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params dependant.path_params, request.path_params
) )
@ -659,17 +768,19 @@ async def solve_dependencies(
values.update(header_values) values.update(header_values)
values.update(cookie_values) values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params: if dependant.body_params:
( (
body_values, body_values,
body_errors, body_errors,
) = await request_body_to_args( # body_params checked above ) = await request_body_to_args(
body_fields=dependant.body_params, body_fields=dependant.body_params,
received_body=body, received_body=body,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
values.update(body_values) values.update(body_values)
errors.extend(body_errors) errors.extend(body_errors)
if dependant.http_connection_param_name: if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request): if dependant.request_param_name and isinstance(request, Request):
@ -677,8 +788,6 @@ async def solve_dependencies(
elif dependant.websocket_param_name and isinstance(request, WebSocket): elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name: if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name: if dependant.response_param_name:
values[dependant.response_param_name] = response values[dependant.response_param_name] = response
@ -686,6 +795,7 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes( values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes scopes=dependant.security_scopes
) )
return SolvedDependency( return SolvedDependency(
values=values, values=values,
errors=errors, errors=errors,

2
tests/test_dependency_cache.py

@ -29,7 +29,7 @@ async def get_sub_counter(
@app.get("/sub-counter-no-cache/") @app.get("/sub-counter-no-cache/")
async def get_sub_counter_no_cache( async def get_sub_counter_no_cache(
subcount: int = Depends(super_dep), subcount: int = Depends(dep_counter),
count: int = Depends(dep_counter, use_cache=False), count: int = Depends(dep_counter, use_cache=False),
): ):
return {"counter": count, "subcounter": subcount} return {"counter": count, "subcounter": subcount}

Loading…
Cancel
Save