|
|
@ -1,4 +1,6 @@ |
|
|
|
import asyncio |
|
|
|
import inspect |
|
|
|
from asyncio import Future |
|
|
|
from contextlib import AsyncExitStack, contextmanager |
|
|
|
from copy import copy, deepcopy |
|
|
|
from dataclasses import dataclass |
|
|
@ -569,6 +571,33 @@ class SolvedDependency: |
|
|
|
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] |
|
|
|
|
|
|
|
|
|
|
|
class DependencySolveException(Exception): |
|
|
|
def __init__(self, errors: List[Any]): |
|
|
|
super().__init__(str(errors)) |
|
|
|
self.errors = errors |
|
|
|
|
|
|
|
|
|
|
|
def is_context_sensitive(dependant: Dependant) -> bool: |
|
|
|
if dependant.call is not None and ( |
|
|
|
is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call) |
|
|
|
): |
|
|
|
return True |
|
|
|
return any(is_context_sensitive(sub) for sub in dependant.dependencies) |
|
|
|
|
|
|
|
|
|
|
|
def is_context_with_background_task(dependant: Dependant) -> bool: |
|
|
|
if dependant.background_tasks_param_name: |
|
|
|
return True |
|
|
|
return any(is_context_with_background_task(sub) for sub in dependant.dependencies) |
|
|
|
|
|
|
|
|
|
|
|
def silence_future_exception(fut: "Future[Any]") -> None: |
|
|
|
try: |
|
|
|
fut.exception() |
|
|
|
except BaseException: |
|
|
|
pass # Silences the warning |
|
|
|
|
|
|
|
|
|
|
|
async def solve_dependencies( |
|
|
|
*, |
|
|
|
request: Union[Request, WebSocket], |
|
|
@ -583,19 +612,25 @@ async def solve_dependencies( |
|
|
|
) -> SolvedDependency: |
|
|
|
values: Dict[str, Any] = {} |
|
|
|
errors: List[Any] = [] |
|
|
|
|
|
|
|
if response is None: |
|
|
|
response = Response() |
|
|
|
del response.headers["content-length"] |
|
|
|
response.status_code = None # type: ignore |
|
|
|
|
|
|
|
dependency_cache = dependency_cache or {} |
|
|
|
sub_dependant: Dependant |
|
|
|
for sub_dependant in dependant.dependencies: |
|
|
|
|
|
|
|
if background_tasks is None and is_context_with_background_task(dependant): |
|
|
|
background_tasks = BackgroundTasks() |
|
|
|
|
|
|
|
async def resolve_sub_dependant( |
|
|
|
sub_dependant: Dependant, |
|
|
|
) -> Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]: |
|
|
|
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) |
|
|
|
sub_dependant.cache_key = cast( |
|
|
|
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key |
|
|
|
) |
|
|
|
cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key) |
|
|
|
call = sub_dependant.call |
|
|
|
use_sub_dependant = sub_dependant |
|
|
|
|
|
|
|
if ( |
|
|
|
dependency_overrides_provider |
|
|
|
and dependency_overrides_provider.dependency_overrides |
|
|
@ -612,6 +647,27 @@ async def solve_dependencies( |
|
|
|
security_scopes=sub_dependant.security_scopes, |
|
|
|
) |
|
|
|
|
|
|
|
def resolve_cached() -> Optional["Future[Any]"]: |
|
|
|
if not sub_dependant.use_cache: |
|
|
|
return None |
|
|
|
|
|
|
|
if cache_key not in dependency_cache: |
|
|
|
future = asyncio.get_event_loop().create_future() |
|
|
|
# Ensures ignored exceptions are not logged with warning, as we only raise the first one. |
|
|
|
future.add_done_callback(silence_future_exception) |
|
|
|
dependency_cache[cache_key] = future |
|
|
|
return None |
|
|
|
|
|
|
|
cached = dependency_cache[cache_key] |
|
|
|
|
|
|
|
if isinstance(cached, Future): |
|
|
|
return cached |
|
|
|
|
|
|
|
future = Future() |
|
|
|
future.set_result(cached) |
|
|
|
return future |
|
|
|
|
|
|
|
async def resolve_value() -> Any: |
|
|
|
solved_result = await solve_dependencies( |
|
|
|
request=request, |
|
|
|
dependant=use_sub_dependant, |
|
|
@ -623,25 +679,78 @@ async def solve_dependencies( |
|
|
|
async_exit_stack=async_exit_stack, |
|
|
|
embed_body_fields=embed_body_fields, |
|
|
|
) |
|
|
|
background_tasks = solved_result.background_tasks |
|
|
|
dependency_cache.update(solved_result.dependency_cache) |
|
|
|
|
|
|
|
if solved_result.errors: |
|
|
|
errors.extend(solved_result.errors) |
|
|
|
continue |
|
|
|
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: |
|
|
|
solved = dependency_cache[sub_dependant.cache_key] |
|
|
|
elif is_gen_callable(call) or is_async_gen_callable(call): |
|
|
|
solved = await solve_generator( |
|
|
|
call=call, stack=async_exit_stack, sub_values=solved_result.values |
|
|
|
raise DependencySolveException(solved_result.errors) |
|
|
|
|
|
|
|
if is_gen_callable(call) or is_async_gen_callable(call): |
|
|
|
return await solve_generator( |
|
|
|
call=call, |
|
|
|
stack=async_exit_stack, |
|
|
|
sub_values=solved_result.values, |
|
|
|
) |
|
|
|
elif is_coroutine_callable(call): |
|
|
|
solved = await call(**solved_result.values) |
|
|
|
return await call(**solved_result.values) |
|
|
|
else: |
|
|
|
return await run_in_threadpool(call, **solved_result.values) |
|
|
|
|
|
|
|
task = resolve_cached() or resolve_value() |
|
|
|
|
|
|
|
def ensure_cache(value: Optional[Any]) -> None: |
|
|
|
if ( |
|
|
|
cache_key not in dependency_cache |
|
|
|
or not isinstance(dependency_cache[cache_key], Future) |
|
|
|
or dependency_cache[cache_key].done() |
|
|
|
): |
|
|
|
return |
|
|
|
|
|
|
|
if isinstance(value, BaseException): |
|
|
|
dependency_cache[cache_key].set_exception(value) |
|
|
|
else: |
|
|
|
solved = await run_in_threadpool(call, **solved_result.values) |
|
|
|
if sub_dependant.name is not None: |
|
|
|
values[sub_dependant.name] = solved |
|
|
|
if sub_dependant.cache_key not in dependency_cache: |
|
|
|
dependency_cache[sub_dependant.cache_key] = solved |
|
|
|
dependency_cache[cache_key].set_result(value) |
|
|
|
|
|
|
|
try: |
|
|
|
resolved = await task |
|
|
|
except DependencySolveException as exc: |
|
|
|
ensure_cache(None) |
|
|
|
return sub_dependant.name, None, exc.errors, None |
|
|
|
except BaseException as resolution_exc: |
|
|
|
ensure_cache(resolution_exc) |
|
|
|
return sub_dependant.name, None, None, resolution_exc |
|
|
|
|
|
|
|
ensure_cache(resolved) |
|
|
|
return sub_dependant.name, resolved, None, None |
|
|
|
|
|
|
|
sequential_deps = [] |
|
|
|
parallel_deps = [] |
|
|
|
for sub in dependant.dependencies: |
|
|
|
if is_context_sensitive(sub): |
|
|
|
sequential_deps.append(sub) |
|
|
|
else: |
|
|
|
parallel_deps.append(sub) |
|
|
|
|
|
|
|
results = [] |
|
|
|
|
|
|
|
for sub in sequential_deps: |
|
|
|
result = await resolve_sub_dependant(sub) |
|
|
|
results.append(result) |
|
|
|
|
|
|
|
if parallel_deps: |
|
|
|
parallel_results = await asyncio.gather( |
|
|
|
*[resolve_sub_dependant(sub) for sub in parallel_deps] |
|
|
|
) |
|
|
|
results.extend(parallel_results) |
|
|
|
|
|
|
|
for name, result, sub_errors, sub_exception in results: |
|
|
|
# Ensures order of exception based on dependency order. |
|
|
|
if sub_exception: |
|
|
|
raise sub_exception |
|
|
|
if sub_errors: |
|
|
|
errors.extend(sub_errors) |
|
|
|
continue |
|
|
|
if name is not None: |
|
|
|
values[name] = result |
|
|
|
|
|
|
|
path_values, path_errors = request_params_to_args( |
|
|
|
dependant.path_params, request.path_params |
|
|
|
) |
|
|
@ -659,17 +768,19 @@ async def solve_dependencies( |
|
|
|
values.update(header_values) |
|
|
|
values.update(cookie_values) |
|
|
|
errors += path_errors + query_errors + header_errors + cookie_errors |
|
|
|
|
|
|
|
if dependant.body_params: |
|
|
|
( |
|
|
|
body_values, |
|
|
|
body_errors, |
|
|
|
) = await request_body_to_args( # body_params checked above |
|
|
|
) = await request_body_to_args( |
|
|
|
body_fields=dependant.body_params, |
|
|
|
received_body=body, |
|
|
|
embed_body_fields=embed_body_fields, |
|
|
|
) |
|
|
|
values.update(body_values) |
|
|
|
errors.extend(body_errors) |
|
|
|
|
|
|
|
if dependant.http_connection_param_name: |
|
|
|
values[dependant.http_connection_param_name] = request |
|
|
|
if dependant.request_param_name and isinstance(request, Request): |
|
|
@ -677,8 +788,6 @@ async def solve_dependencies( |
|
|
|
elif dependant.websocket_param_name and isinstance(request, WebSocket): |
|
|
|
values[dependant.websocket_param_name] = request |
|
|
|
if dependant.background_tasks_param_name: |
|
|
|
if background_tasks is None: |
|
|
|
background_tasks = BackgroundTasks() |
|
|
|
values[dependant.background_tasks_param_name] = background_tasks |
|
|
|
if dependant.response_param_name: |
|
|
|
values[dependant.response_param_name] = response |
|
|
@ -686,6 +795,7 @@ async def solve_dependencies( |
|
|
|
values[dependant.security_scopes_param_name] = SecurityScopes( |
|
|
|
scopes=dependant.security_scopes |
|
|
|
) |
|
|
|
|
|
|
|
return SolvedDependency( |
|
|
|
values=values, |
|
|
|
errors=errors, |
|
|
|