Browse Source

Merge d518b23bef into 6e69d62bfe

pull/13756/merge
Fodor Zoltan 2 days ago
committed by GitHub
parent
commit
b42f0f741c
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 1
      fastapi/dependencies/models.py
  2. 172
      fastapi/dependencies/utils.py
  3. 12
      fastapi/params.py
  4. 2
      tests/test_dependency_cache.py

1
fastapi/dependencies/models.py

@ -32,6 +32,7 @@ class Dependant:
use_cache: bool = True
path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
parallelizable: bool = True
def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

172
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import asyncio
import inspect
from asyncio import Future
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
@ -165,6 +167,7 @@ def get_sub_dependant(
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
parallelizable=depends.parallelizable,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
@ -192,6 +195,7 @@ def get_flat_dependant(
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
use_cache=dependant.use_cache,
parallelizable=dependant.parallelizable,
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
@ -269,6 +273,7 @@ def get_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
parallelizable: bool = True,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
@ -279,6 +284,7 @@ def get_dependant(
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
parallelizable=parallelizable,
)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
@ -569,6 +575,34 @@ class SolvedDependency:
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
class DependencySolveException(Exception):
def __init__(self, errors: List[Any]):
super().__init__(str(errors))
self.errors = errors
def is_context_sensitive(dependant: Dependant) -> bool:
if dependant.parallelizable is False or (
dependant.call is not None
and (is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call))
):
return True
return any(is_context_sensitive(sub) for sub in dependant.dependencies)
def is_context_with_background_task(dependant: Dependant) -> bool:
if dependant.background_tasks_param_name:
return True
return any(is_context_with_background_task(sub) for sub in dependant.dependencies)
def silence_future_exception(fut: "Future[Any]") -> None:
try:
fut.exception()
except BaseException:
pass # Silences the warning
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
@ -583,19 +617,25 @@ async def solve_dependencies(
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
if response is None:
response = Response()
del response.headers["content-length"]
response.status_code = None # type: ignore
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
if background_tasks is None and is_context_with_background_task(dependant):
background_tasks = BackgroundTasks()
async def resolve_sub_dependant(
sub_dependant: Dependant,
) -> Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
@ -612,6 +652,27 @@ async def solve_dependencies(
security_scopes=sub_dependant.security_scopes,
)
def resolve_cached() -> Optional["Future[Any]"]:
if not sub_dependant.use_cache:
return None
if cache_key not in dependency_cache:
future = asyncio.get_event_loop().create_future()
# Ensures ignored exceptions are not logged with warning, as we only raise the first one.
future.add_done_callback(silence_future_exception)
dependency_cache[cache_key] = future
return None
cached = dependency_cache[cache_key]
if isinstance(cached, Future):
return cached
future = Future()
future.set_result(cached)
return future
async def resolve_value() -> Any:
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
@ -623,25 +684,89 @@ async def solve_dependencies(
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
background_tasks = solved_result.background_tasks
dependency_cache.update(solved_result.dependency_cache)
if solved_result.errors:
errors.extend(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values
raise DependencySolveException(solved_result.errors)
if is_gen_callable(call) or is_async_gen_callable(call):
return await solve_generator(
call=call,
stack=async_exit_stack,
sub_values=solved_result.values,
)
elif is_coroutine_callable(call):
solved = await call(**solved_result.values)
return await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
return await run_in_threadpool(call, **solved_result.values)
task = resolve_cached() or resolve_value()
def ensure_cache(value: Optional[Any]) -> None:
if (
cache_key not in dependency_cache
or not isinstance(dependency_cache[cache_key], Future)
or dependency_cache[cache_key].done()
):
return
if isinstance(value, BaseException):
dependency_cache[cache_key].set_exception(value)
else:
dependency_cache[cache_key].set_result(value)
try:
resolved = await task
except DependencySolveException as exc:
ensure_cache(None)
return sub_dependant.name, None, exc.errors, None
except BaseException as resolution_exc:
ensure_cache(resolution_exc)
return sub_dependant.name, None, None, resolution_exc
ensure_cache(resolved)
return sub_dependant.name, resolved, None, None
def unpack_results(
results: list[
Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]
],
) -> None:
for name, value, sub_errors, sub_exception in results:
# Ensures order of exception based on dependency order.
if sub_exception:
raise sub_exception
if sub_errors:
errors.extend(sub_errors)
continue
if name is not None:
values[name] = value
sequential_deps = []
parallel_deps = []
for sub in dependant.dependencies:
if is_context_sensitive(sub):
sequential_deps.append(sub)
else:
parallel_deps.append(sub)
sequential_results: list[
Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]
] = []
for sub in sequential_deps:
s_result = await resolve_sub_dependant(sub)
sequential_results.append(s_result)
unpack_results(sequential_results)
parallel_results: list[
Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]
] = []
if parallel_deps:
p_result = await asyncio.gather(
*[resolve_sub_dependant(sub) for sub in parallel_deps]
)
parallel_results.extend(p_result)
unpack_results(parallel_results)
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
@ -659,17 +784,19 @@ async def solve_dependencies(
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
(
body_values,
body_errors,
) = await request_body_to_args( # body_params checked above
) = await request_body_to_args(
body_fields=dependant.body_params,
received_body=body,
embed_body_fields=embed_body_fields,
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
@ -677,8 +804,6 @@ async def solve_dependencies(
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name:
values[dependant.response_param_name] = response
@ -686,6 +811,7 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return SolvedDependency(
values=values,
errors=errors,

12
fastapi/params.py

@ -763,10 +763,15 @@ class File(Form):
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
parallelizable: bool = True,
):
self.dependency = dependency
self.use_cache = use_cache
self.parallelizable = parallelizable
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
@ -781,6 +786,9 @@ class Security(Depends):
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
parallelizable: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
super().__init__(
dependency=dependency, use_cache=use_cache, parallelizable=parallelizable
)
self.scopes = scopes or []

2
tests/test_dependency_cache.py

@ -29,7 +29,7 @@ async def get_sub_counter(
@app.get("/sub-counter-no-cache/")
async def get_sub_counter_no_cache(
subcount: int = Depends(super_dep),
subcount: int = Depends(dep_counter),
count: int = Depends(dep_counter, use_cache=False),
):
return {"counter": count, "subcounter": subcount}

Loading…
Cancel
Save