Browse Source

Merge d518b23bef into 6e69d62bfe

pull/13756/merge
Fodor Zoltan 2 days ago
committed by GitHub
parent
commit
b42f0f741c
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 1
      fastapi/dependencies/models.py
  2. 198
      fastapi/dependencies/utils.py
  3. 12
      fastapi/params.py
  4. 2
      tests/test_dependency_cache.py

1
fastapi/dependencies/models.py

@ -32,6 +32,7 @@ class Dependant:
use_cache: bool = True use_cache: bool = True
path: Optional[str] = None path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
parallelizable: bool = True
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

198
fastapi/dependencies/utils.py

@ -1,4 +1,6 @@
import asyncio
import inspect import inspect
from asyncio import Future
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -165,6 +167,7 @@ def get_sub_dependant(
name=name, name=name,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=depends.use_cache, use_cache=depends.use_cache,
parallelizable=depends.parallelizable,
) )
if security_requirement: if security_requirement:
sub_dependant.security_requirements.append(security_requirement) sub_dependant.security_requirements.append(security_requirement)
@ -192,6 +195,7 @@ def get_flat_dependant(
body_params=dependant.body_params.copy(), body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(), security_requirements=dependant.security_requirements.copy(),
use_cache=dependant.use_cache, use_cache=dependant.use_cache,
parallelizable=dependant.parallelizable,
path=dependant.path, path=dependant.path,
) )
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
@ -269,6 +273,7 @@ def get_dependant(
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
use_cache: bool = True, use_cache: bool = True,
parallelizable: bool = True,
) -> Dependant: ) -> Dependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call) endpoint_signature = get_typed_signature(call)
@ -279,6 +284,7 @@ def get_dependant(
path=path, path=path,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=use_cache, use_cache=use_cache,
parallelizable=parallelizable,
) )
for param_name, param in signature_params.items(): for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names is_path_param = param_name in path_param_names
@ -569,6 +575,34 @@ class SolvedDependency:
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
class DependencySolveException(Exception):
def __init__(self, errors: List[Any]):
super().__init__(str(errors))
self.errors = errors
def is_context_sensitive(dependant: Dependant) -> bool:
if dependant.parallelizable is False or (
dependant.call is not None
and (is_gen_callable(dependant.call) or is_async_gen_callable(dependant.call))
):
return True
return any(is_context_sensitive(sub) for sub in dependant.dependencies)
def is_context_with_background_task(dependant: Dependant) -> bool:
if dependant.background_tasks_param_name:
return True
return any(is_context_with_background_task(sub) for sub in dependant.dependencies)
def silence_future_exception(fut: "Future[Any]") -> None:
try:
fut.exception()
except BaseException:
pass # Silences the warning
async def solve_dependencies( async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
@ -583,19 +617,25 @@ async def solve_dependencies(
) -> SolvedDependency: ) -> SolvedDependency:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Any] = [] errors: List[Any] = []
if response is None: if response is None:
response = Response() response = Response()
del response.headers["content-length"] del response.headers["content-length"]
response.status_code = None # type: ignore response.status_code = None # type: ignore
dependency_cache = dependency_cache or {} dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies: if background_tasks is None and is_context_with_background_task(dependant):
background_tasks = BackgroundTasks()
async def resolve_sub_dependant(
sub_dependant: Dependant,
) -> Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast( cache_key = cast(Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key)
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call call = sub_dependant.call
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependant
if ( if (
dependency_overrides_provider dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides and dependency_overrides_provider.dependency_overrides
@ -612,36 +652,121 @@ async def solve_dependencies(
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
) )
solved_result = await solve_dependencies( def resolve_cached() -> Optional["Future[Any]"]:
request=request, if not sub_dependant.use_cache:
dependant=use_sub_dependant, return None
body=body,
background_tasks=background_tasks, if cache_key not in dependency_cache:
response=response, future = asyncio.get_event_loop().create_future()
dependency_overrides_provider=dependency_overrides_provider, # Ensures ignored exceptions are not logged with warning, as we only raise the first one.
dependency_cache=dependency_cache, future.add_done_callback(silence_future_exception)
async_exit_stack=async_exit_stack, dependency_cache[cache_key] = future
embed_body_fields=embed_body_fields, return None
)
background_tasks = solved_result.background_tasks cached = dependency_cache[cache_key]
dependency_cache.update(solved_result.dependency_cache)
if solved_result.errors: if isinstance(cached, Future):
errors.extend(solved_result.errors) return cached
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: future = Future()
solved = dependency_cache[sub_dependant.cache_key] future.set_result(cached)
elif is_gen_callable(call) or is_async_gen_callable(call): return future
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values async def resolve_value() -> Any:
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
) )
elif is_coroutine_callable(call):
solved = await call(**solved_result.values) if solved_result.errors:
raise DependencySolveException(solved_result.errors)
if is_gen_callable(call) or is_async_gen_callable(call):
return await solve_generator(
call=call,
stack=async_exit_stack,
sub_values=solved_result.values,
)
elif is_coroutine_callable(call):
return await call(**solved_result.values)
else:
return await run_in_threadpool(call, **solved_result.values)
task = resolve_cached() or resolve_value()
def ensure_cache(value: Optional[Any]) -> None:
if (
cache_key not in dependency_cache
or not isinstance(dependency_cache[cache_key], Future)
or dependency_cache[cache_key].done()
):
return
if isinstance(value, BaseException):
dependency_cache[cache_key].set_exception(value)
else:
dependency_cache[cache_key].set_result(value)
try:
resolved = await task
except DependencySolveException as exc:
ensure_cache(None)
return sub_dependant.name, None, exc.errors, None
except BaseException as resolution_exc:
ensure_cache(resolution_exc)
return sub_dependant.name, None, None, resolution_exc
ensure_cache(resolved)
return sub_dependant.name, resolved, None, None
def unpack_results(
results: list[
Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]
],
) -> None:
for name, value, sub_errors, sub_exception in results:
# Ensures order of exception based on dependency order.
if sub_exception:
raise sub_exception
if sub_errors:
errors.extend(sub_errors)
continue
if name is not None:
values[name] = value
sequential_deps = []
parallel_deps = []
for sub in dependant.dependencies:
if is_context_sensitive(sub):
sequential_deps.append(sub)
else: else:
solved = await run_in_threadpool(call, **solved_result.values) parallel_deps.append(sub)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved sequential_results: list[
if sub_dependant.cache_key not in dependency_cache: Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]
dependency_cache[sub_dependant.cache_key] = solved ] = []
for sub in sequential_deps:
s_result = await resolve_sub_dependant(sub)
sequential_results.append(s_result)
unpack_results(sequential_results)
parallel_results: list[
Tuple[Optional[str], Any, Optional[List[Any]], Optional[BaseException]]
] = []
if parallel_deps:
p_result = await asyncio.gather(
*[resolve_sub_dependant(sub) for sub in parallel_deps]
)
parallel_results.extend(p_result)
unpack_results(parallel_results)
path_values, path_errors = request_params_to_args( path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params dependant.path_params, request.path_params
) )
@ -659,17 +784,19 @@ async def solve_dependencies(
values.update(header_values) values.update(header_values)
values.update(cookie_values) values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params: if dependant.body_params:
( (
body_values, body_values,
body_errors, body_errors,
) = await request_body_to_args( # body_params checked above ) = await request_body_to_args(
body_fields=dependant.body_params, body_fields=dependant.body_params,
received_body=body, received_body=body,
embed_body_fields=embed_body_fields, embed_body_fields=embed_body_fields,
) )
values.update(body_values) values.update(body_values)
errors.extend(body_errors) errors.extend(body_errors)
if dependant.http_connection_param_name: if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request): if dependant.request_param_name and isinstance(request, Request):
@ -677,8 +804,6 @@ async def solve_dependencies(
elif dependant.websocket_param_name and isinstance(request, WebSocket): elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name: if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name: if dependant.response_param_name:
values[dependant.response_param_name] = response values[dependant.response_param_name] = response
@ -686,6 +811,7 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes( values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes scopes=dependant.security_scopes
) )
return SolvedDependency( return SolvedDependency(
values=values, values=values,
errors=errors, errors=errors,

12
fastapi/params.py

@ -763,10 +763,15 @@ class File(Form):
class Depends: class Depends:
def __init__( def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
parallelizable: bool = True,
): ):
self.dependency = dependency self.dependency = dependency
self.use_cache = use_cache self.use_cache = use_cache
self.parallelizable = parallelizable
def __repr__(self) -> str: def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
@ -781,6 +786,9 @@ class Security(Depends):
*, *,
scopes: Optional[Sequence[str]] = None, scopes: Optional[Sequence[str]] = None,
use_cache: bool = True, use_cache: bool = True,
parallelizable: bool = True,
): ):
super().__init__(dependency=dependency, use_cache=use_cache) super().__init__(
dependency=dependency, use_cache=use_cache, parallelizable=parallelizable
)
self.scopes = scopes or [] self.scopes = scopes or []

2
tests/test_dependency_cache.py

@ -29,7 +29,7 @@ async def get_sub_counter(
@app.get("/sub-counter-no-cache/") @app.get("/sub-counter-no-cache/")
async def get_sub_counter_no_cache( async def get_sub_counter_no_cache(
subcount: int = Depends(super_dep), subcount: int = Depends(dep_counter),
count: int = Depends(dep_counter, use_cache=False), count: int = Depends(dep_counter, use_cache=False),
): ):
return {"counter": count, "subcounter": subcount} return {"counter": count, "subcounter": subcount}

Loading…
Cancel
Save