From 50e89daa5237e4ac5ee07374fd87e98621f5556a Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Fri, 26 Jul 2024 11:23:32 +0200 Subject: [PATCH] add limiter keyword argument to Dependency and Security --- fastapi/dependencies/models.py | 2 ++ fastapi/dependencies/utils.py | 19 ++++++++++++++----- fastapi/param_functions.py | 32 ++++++++++++++++++++++++++++++-- fastapi/params.py | 13 ++++++++++--- 4 files changed, 56 insertions(+), 10 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..0891213e8 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Sequence, Tuple +import anyio from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -32,6 +33,7 @@ class Dependant: use_cache: bool = True path: Optional[str] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) + limiter: Optional[anyio.CapacityLimiter] = None def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0dcba62f1..96a36de1d 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -49,6 +49,7 @@ from fastapi.background import BackgroundTasks from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, + run_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.logger import logger @@ -58,7 +59,6 @@ from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.utils import create_model_field, get_path_param_names from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks -from starlette.concurrency import run_in_threadpool from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.requests import HTTPConnection, Request from starlette.responses import Response @@ -149,6 +149,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + limiter=depends.limiter, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -176,6 +177,7 @@ def get_flat_dependant( body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), use_cache=dependant.use_cache, + limiter=dependant.limiter, path=dependant.path, ) for sub_dependant in dependant.dependencies: @@ -244,6 +246,7 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + limiter: Optional[anyio.CapacityLimiter] = None, ) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -254,6 +257,7 @@ def get_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + limiter=limiter, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -529,10 +533,12 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] + *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any], + limiter: Optional[anyio.CapacityLimiter] = None, ) -> Any: if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) + cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values), + limiter=limiter) elif is_async_gen_callable(call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -587,6 +593,7 @@ async def solve_dependencies( call=call, name=sub_dependant.name, security_scopes=sub_dependant.security_scopes, + limiter=sub_dependant.limiter, ) solved_result = await solve_dependencies( @@ -608,12 +615,14 @@ async def solve_dependencies( solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values + call=call, stack=async_exit_stack, sub_values=solved_result.values, + limiter=sub_dependant.limiter, ) elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, **solved_result.values) + solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter, + **solved_result.values) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 0d5f27af4..77accd9f6 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import anyio from fastapi import params from fastapi._compat import Undefined @@ -2244,6 +2245,20 @@ def Depends( # noqa: N802 """ ), ] = True, + limiter: Annotated[ + anyio.CapacityLimiter, + Doc( + """ + By default, synchronous dependencies will be run in a threadpool + with the number of concurrent threads limited by the current default anyio + thread limiter. A different `anyio.CapacityLimiter` may be specified + for problematic dependencies to use a different (logical) thread pool with + other limits in order to avoid blocking other threads. + + For async dependencies (defined using `async def`) this parameter is ignored. + """ + ), + ] = None, ) -> Any: """ Declare a FastAPI dependency. @@ -2274,7 +2289,7 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache) + return params.Depends(dependency=dependency, use_cache=use_cache, limiter=limiter) def Security( # noqa: N802 @@ -2321,6 +2336,18 @@ def Security( # noqa: N802 """ ), ] = True, + limiter: Annotated[ + anyio.CapacityLimiter, + Doc( + """ + By default, synchronous dependencies will be run in a threadpool + with the number of concurrent threads limited by the current default anyio + thread limiter. A different `anyio.CapacityLimiter` may be specified + for problematic dependencies to use a different (logical) thread pool with + other limits in order to avoid blocking other threads. + """ + ), + ] = None, ) -> Any: """ Declare a FastAPI Security dependency. @@ -2357,4 +2384,5 @@ def Security( # noqa: N802 return [{"item_id": "Foo", "owner": current_user.username}] ``` """ - return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache) + return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache, + limiter=limiter) diff --git a/fastapi/params.py b/fastapi/params.py index cc2a5c13c..499c2d66f 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -1,6 +1,7 @@ import warnings from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import anyio from fastapi.openapi.models import Example from pydantic.fields import FieldInfo @@ -760,15 +761,20 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, dependency: Optional[Callable[..., Any]] = None, *, + use_cache: bool = True, + limiter: anyio.CapacityLimiter | None = None, ): self.dependency = dependency self.use_cache = use_cache + self.limiter = limiter def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({attr}{cache})" + limiter = f", limiter=CapacityLimiter({self.limiter.total_tokens})" \ + if self.limiter else "" + return f"{self.__class__.__name__}({attr}{cache}{limiter})" class Security(Depends): @@ -778,6 +784,7 @@ class Security(Depends): *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, + limiter: anyio.CapacityLimiter | None = None, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter) self.scopes = scopes or []