Browse Source

Parallelize dependency resolution. (#639)

pull/13756/head
Zoltan Fodor 2 months ago
parent
commit
8dbdd3d89c
  1. 182
      fastapi/dependencies/utils.py
  2. 2
      tests/test_dependency_cache.py

182
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import asyncio
import inspect
from asyncio import Future
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
@ -569,6 +571,33 @@ class SolvedDependency:
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
class DependencySolveException(Exception):
def __init__(self, errors: List[Any]):
super().__init__(str(errors))
self.errors = errors
def is_context_sensitive(dependant: Dependant) -> bool:
if dependant.call is not None and (
is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call)
):
return True
return any(is_context_sensitive(sub) for sub in dependant.dependencies)
def is_context_with_background_task(dependant: Dependant) -> bool:
if dependant.background_tasks_param_name:
return True
return any(is_context_with_background_task(sub) for sub in dependant.dependencies)
def silence_future_exception(fut: "Future[Any]") -> None:
try:
fut.exception()
except BaseException:
pass # Silences the warning
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
@ -583,19 +612,25 @@ async def solve_dependencies(
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
if response is None:
response = Response()
del response.headers["content-length"]
response.status_code = None # type: ignore
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
if background_tasks is None and is_context_with_background_task(dependant):
background_tasks = BackgroundTasks()
async def resolve_sub_dependant(
sub_dependant: Dependant,
) -> Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
@ -612,36 +647,110 @@ async def solve_dependencies(
security_scopes=sub_dependant.security_scopes,
)
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
background_tasks = solved_result.background_tasks
dependency_cache.update(solved_result.dependency_cache)
if solved_result.errors:
errors.extend(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values
def resolve_cached() -> Optional["Future[Any]"]:
if not sub_dependant.use_cache:
return None
if cache_key not in dependency_cache:
future = asyncio.get_event_loop().create_future()
# Ensures ignored exceptions are not logged with warning, as we only raise the first one.
future.add_done_callback(silence_future_exception)
dependency_cache[cache_key] = future
return None
cached = dependency_cache[cache_key]
if isinstance(cached, Future):
return cached
future = Future()
future.set_result(cached)
return future
async def resolve_value() -> Any:
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
elif is_coroutine_callable(call):
solved = await call(**solved_result.values)
if solved_result.errors:
raise DependencySolveException(solved_result.errors)
if is_gen_callable(call) or is_async_gen_callable(call):
return await solve_generator(
call=call,
stack=async_exit_stack,
sub_values=solved_result.values,
)
elif is_coroutine_callable(call):
return await call(**solved_result.values)
else:
return await run_in_threadpool(call, **solved_result.values)
task = resolve_cached() or resolve_value()
def ensure_cache(value: Optional[Any]) -> None:
if (
cache_key not in dependency_cache
or not isinstance(dependency_cache[cache_key], Future)
or dependency_cache[cache_key].done()
):
return
if isinstance(value, BaseException):
dependency_cache[cache_key].set_exception(value)
else:
dependency_cache[cache_key].set_result(value)
try:
resolved = await task
except DependencySolveException as exc:
ensure_cache(None)
return sub_dependant.name, None, exc.errors, None
except BaseException as resolution_exc:
ensure_cache(resolution_exc)
return sub_dependant.name, None, None, resolution_exc
ensure_cache(resolved)
return sub_dependant.name, resolved, None, None
sequential_deps = []
parallel_deps = []
for sub in dependant.dependencies:
if is_context_sensitive(sub):
sequential_deps.append(sub)
else:
solved = await run_in_threadpool(call, **solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
parallel_deps.append(sub)
results = []
for sub in sequential_deps:
result = await resolve_sub_dependant(sub)
results.append(result)
if parallel_deps:
parallel_results = await asyncio.gather(
*[resolve_sub_dependant(sub) for sub in parallel_deps]
)
results.extend(parallel_results)
for name, result, sub_errors, sub_exception in results:
# Ensures order of exception based on dependency order.
if sub_exception:
raise sub_exception
if sub_errors:
errors.extend(sub_errors)
continue
if name is not None:
values[name] = result
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
@ -659,17 +768,19 @@ async def solve_dependencies(
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
(
body_values,
body_errors,
) = await request_body_to_args( # body_params checked above
) = await request_body_to_args(
body_fields=dependant.body_params,
received_body=body,
embed_body_fields=embed_body_fields,
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
@ -677,8 +788,6 @@ async def solve_dependencies(
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name:
values[dependant.response_param_name] = response
@ -686,6 +795,7 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return SolvedDependency(
values=values,
errors=errors,

2
tests/test_dependency_cache.py

@ -29,7 +29,7 @@ async def get_sub_counter(
@app.get("/sub-counter-no-cache/")
async def get_sub_counter_no_cache(
subcount: int = Depends(super_dep),
subcount: int = Depends(dep_counter),
count: int = Depends(dep_counter, use_cache=False),
):
return {"counter": count, "subcounter": subcount}

Loading…
Cancel
Save