Browse Source

Parallelize dependency resolution. (#639)

pull/13756/head
Zoltan Fodor 2 months ago
parent
commit
8dbdd3d89c
  1. 182
      fastapi/dependencies/utils.py
  2. 2
      tests/test_dependency_cache.py

182
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import asyncio
import inspect import inspect
from asyncio import Future
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -569,6 +571,33 @@ class SolvedDependency:
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
class DependencySolveException(Exception):
def __init__(self, errors: List[Any]):
super().__init__(str(errors))
self.errors = errors
def is_context_sensitive(dependant: Dependant) -> bool:
if dependant.call is not None and (
is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call)
):
return True
return any(is_context_sensitive(sub) for sub in dependant.dependencies)
def is_context_with_background_task(dependant: Dependant) -> bool:
if dependant.background_tasks_param_name:
return True
return any(is_context_with_background_task(sub) for sub in dependant.dependencies)
def silence_future_exception(fut: "Future[Any]") -> None:
try:
fut.exception()
except BaseException:
pass # Silences the warning
async def solve_dependencies( async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
@ -583,19 +612,25 @@ async def solve_dependencies(
) -> SolvedDependency: ) -> SolvedDependency:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Any] = [] errors: List[Any] = []
if response is None: if response is None:
response = Response() response = Response()
del response.headers["content-length"] del response.headers["content-length"]
response.status_code = None # type: ignore response.status_code = None # type: ignore
dependency_cache = dependency_cache or {} dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies: if background_tasks is None and is_context_with_background_task(dependant):
background_tasks = BackgroundTasks()
async def resolve_sub_dependant(
sub_dependant: Dependant,
) -> Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast( cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key)
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call call = sub_dependant.call
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependant
if ( if (
dependency_overrides_provider dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides and dependency_overrides_provider.dependency_overrides
@ -612,36 +647,110 @@ async def solve_dependencies(
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
) )
solved_result = await solve_dependencies( def resolve_cached() -> Optional["Future[Any]"]:
request=request, if not sub_dependant.use_cache:
dependant=use_sub_dependant, return None
body=body,
background_tasks=background_tasks, if cache_key not in dependency_cache:
response=response, future = asyncio.get_event_loop().create_future()
dependency_overrides_provider=dependency_overrides_provider, # Ensures ignored exceptions are not logged with warning, as we only raise the first one.
dependency_cache=dependency_cache, future.add_done_callback(silence_future_exception)
async_exit_stack=async_exit_stack, dependency_cache[cache_key] = future
embed_body_fields=embed_body_fields, return None
)
background_tasks = solved_result.background_tasks cached = dependency_cache[cache_key]
dependency_cache.update(solved_result.dependency_cache)
if solved_result.errors: if isinstance(cached, Future):
errors.extend(solved_result.errors) return cached
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: future = Future()
solved = dependency_cache[sub_dependant.cache_key] future.set_result(cached)
elif is_gen_callable(call) or is_async_gen_callable(call): return future
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values async def resolve_value() -> Any:
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
) )
elif is_coroutine_callable(call):
solved = await call(**solved_result.values) if solved_result.errors:
raise DependencySolveException(solved_result.errors)
if is_gen_callable(call) or is_async_gen_callable(call):
return await solve_generator(
call=call,
stack=async_exit_stack,
sub_values=solved_result.values,
)
elif is_coroutine_callable(call):
return await call(**solved_result.values)
else:
return await run_in_threadpool(call, **solved_result.values)
task = resolve_cached() or resolve_value()
def ensure_cache(value: Optional[Any]) -> None:
if (
cache_key not in dependency_cache
or not isinstance(dependency_cache[cache_key], Future)
or dependency_cache[cache_key].done()
):
return
if isinstance(value, BaseException):
dependency_cache[cache_key].set_exception(value)
else:
dependency_cache[cache_key].set_result(value)
try:
resolved = await task
except DependencySolveException as exc:
ensure_cache(None)
return sub_dependant.name, None, exc.errors, None
except BaseException as resolution_exc:
ensure_cache(resolution_exc)
return sub_dependant.name, None, None, resolution_exc
ensure_cache(resolved)
return sub_dependant.name, resolved, None, None
sequential_deps = []
parallel_deps = []
for sub in dependant.dependencies:
if is_context_sensitive(sub):
sequential_deps.append(sub)
else: else:
solved = await run_in_threadpool(call, **solved_result.values) parallel_deps.append(sub)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved results = []
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved for sub in sequential_deps:
result = await resolve_sub_dependant(sub)
results.append(result)
if parallel_deps:
parallel_results = await asyncio.gather(
*[resolve_sub_dependant(sub) for sub in parallel_deps]
)
results.extend(parallel_results)
for name, result, sub_errors, sub_exception in results:
# Ensures order of exception based on dependency order.
if sub_exception:
raise sub_exception
if sub_errors:
errors.extend(sub_errors)
continue
if name is not None:
values[name] = result
path_values, path_errors = request_params_to_args( path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params dependant.path_params, request.path_params
) )
@ -659,17 +768,19 @@ async def solve_dependencies(
values.update(header_values) values.update(header_values)
values.update(cookie_values) values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params: if dependant.body_params:
( (
body_values, body_values,
body_errors, body_errors,
) = await request_body_to_args( # body_params checked above ) = await request_body_to_args(
body_fields=dependant.body_params, body_fields=dependant.body_params,
received_body=body, received_body=body,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
values.update(body_values) values.update(body_values)
errors.extend(body_errors) errors.extend(body_errors)
if dependant.http_connection_param_name: if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request): if dependant.request_param_name and isinstance(request, Request):
@ -677,8 +788,6 @@ async def solve_dependencies(
elif dependant.websocket_param_name and isinstance(request, WebSocket): elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name: if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name: if dependant.response_param_name:
values[dependant.response_param_name] = response values[dependant.response_param_name] = response
@ -686,6 +795,7 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes( values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes scopes=dependant.security_scopes
) )
return SolvedDependency( return SolvedDependency(
values=values, values=values,
errors=errors, errors=errors,

2
tests/test_dependency_cache.py

@ -29,7 +29,7 @@ async def get_sub_counter(
@app.get("/sub-counter-no-cache/") @app.get("/sub-counter-no-cache/")
async def get_sub_counter_no_cache( async def get_sub_counter_no_cache(
subcount: int = Depends(super_dep), subcount: int = Depends(dep_counter),
count: int = Depends(dep_counter, use_cache=False), count: int = Depends(dep_counter, use_cache=False),
): ):
return {"counter": count, "subcounter": subcount} return {"counter": count, "subcounter": subcount}

Loading…
Cancel
Save