Browse Source

add limiter keyword argument to Dependency and Security

pull/11895/head
Alexander 'Leo' Bergolth 1 year ago
parent
commit
50e89daa52
  1. 2
      fastapi/dependencies/models.py
  2. 19
      fastapi/dependencies/utils.py
  3. 32
      fastapi/param_functions.py
  4. 13
      fastapi/params.py

2
fastapi/dependencies/models.py

@ -1,5 +1,6 @@
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple
import anyio
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
@ -32,6 +33,7 @@ class Dependant:
use_cache: bool = True
path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
limiter: Optional[anyio.CapacityLimiter] = None
def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

19
fastapi/dependencies/utils.py

@ -49,6 +49,7 @@ from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
run_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
@ -58,7 +59,6 @@ from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import create_model_field, get_path_param_names
from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
@ -149,6 +149,7 @@ def get_sub_dependant(
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
limiter=depends.limiter,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
@ -176,6 +177,7 @@ def get_flat_dependant(
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
use_cache=dependant.use_cache,
limiter=dependant.limiter,
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
@ -244,6 +246,7 @@ def get_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
limiter: Optional[anyio.CapacityLimiter] = None,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
@ -254,6 +257,7 @@ def get_dependant(
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
limiter=limiter,
)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
@ -529,10 +533,12 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any],
limiter: Optional[anyio.CapacityLimiter] = None,
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values),
limiter=limiter)
elif is_async_gen_callable(call):
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
@ -587,6 +593,7 @@ async def solve_dependencies(
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
limiter=sub_dependant.limiter,
)
solved_result = await solve_dependencies(
@ -608,12 +615,14 @@ async def solve_dependencies(
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values
call=call, stack=async_exit_stack, sub_values=solved_result.values,
limiter=sub_dependant.limiter,
)
elif is_coroutine_callable(call):
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)
solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter,
**solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:

32
fastapi/param_functions.py

@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import anyio
from fastapi import params
from fastapi._compat import Undefined
@ -2244,6 +2245,20 @@ def Depends( # noqa: N802
"""
),
] = True,
limiter: Annotated[
anyio.CapacityLimiter,
Doc(
"""
By default, synchronous dependencies will be run in a threadpool
with the number of concurrent threads limited by the current default anyio
thread limiter. A different `anyio.CapacityLimiter` may be specified
for problematic dependencies to use a different (logical) thread pool with
other limits in order to avoid blocking other threads.
For async dependencies (defined using `async def`) this parameter is ignored.
"""
),
] = None,
) -> Any:
"""
Declare a FastAPI dependency.
@ -2274,7 +2289,7 @@ def Depends( # noqa: N802
return commons
```
"""
return params.Depends(dependency=dependency, use_cache=use_cache)
return params.Depends(dependency=dependency, use_cache=use_cache, limiter=limiter)
def Security( # noqa: N802
@ -2321,6 +2336,18 @@ def Security( # noqa: N802
"""
),
] = True,
limiter: Annotated[
anyio.CapacityLimiter,
Doc(
"""
By default, synchronous dependencies will be run in a threadpool
with the number of concurrent threads limited by the current default anyio
thread limiter. A different `anyio.CapacityLimiter` may be specified
for problematic dependencies to use a different (logical) thread pool with
other limits in order to avoid blocking other threads.
"""
),
] = None,
) -> Any:
"""
Declare a FastAPI Security dependency.
@ -2357,4 +2384,5 @@ def Security( # noqa: N802
return [{"item_id": "Foo", "owner": current_user.username}]
```
"""
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache,
limiter=limiter)

13
fastapi/params.py

@ -1,6 +1,7 @@
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import anyio
from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo
@ -760,15 +761,20 @@ class File(Form):
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
self, dependency: Optional[Callable[..., Any]] = None, *,
use_cache: bool = True,
limiter: anyio.CapacityLimiter | None = None,
):
self.dependency = dependency
self.use_cache = use_cache
self.limiter = limiter
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"
limiter = f", limiter=CapacityLimiter({self.limiter.total_tokens})" \
if self.limiter else ""
return f"{self.__class__.__name__}({attr}{cache}{limiter})"
class Security(Depends):
@ -778,6 +784,7 @@ class Security(Depends):
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
limiter: anyio.CapacityLimiter | None = None,
):
super().__init__(dependency=dependency, use_cache=use_cache)
super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter)
self.scopes = scopes or []

Loading…
Cancel
Save