Browse Source

add limiter keyword argument to Dependency and Security

pull/11895/head
Alexander 'Leo' Bergolth 1 year ago
parent
commit
50e89daa52
  1. 2
      fastapi/dependencies/models.py
  2. 19
      fastapi/dependencies/utils.py
  3. 32
      fastapi/param_functions.py
  4. 13
      fastapi/params.py

2
fastapi/dependencies/models.py

@ -1,5 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple from typing import Any, Callable, List, Optional, Sequence, Tuple
import anyio
from fastapi._compat import ModelField from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
@ -32,6 +33,7 @@ class Dependant:
use_cache: bool = True use_cache: bool = True
path: Optional[str] = None path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
limiter: Optional[anyio.CapacityLimiter] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

19
fastapi/dependencies/utils.py

@ -49,6 +49,7 @@ from fastapi.background import BackgroundTasks
from fastapi.concurrency import ( from fastapi.concurrency import (
asynccontextmanager, asynccontextmanager,
contextmanager_in_threadpool, contextmanager_in_threadpool,
run_in_threadpool,
) )
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger from fastapi.logger import logger
@ -58,7 +59,6 @@ from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import create_model_field, get_path_param_names from fastapi.utils import create_model_field, get_path_param_names
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import HTTPConnection, Request from starlette.requests import HTTPConnection, Request
from starlette.responses import Response from starlette.responses import Response
@ -149,6 +149,7 @@ def get_sub_dependant(
name=name, name=name,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=depends.use_cache, use_cache=depends.use_cache,
limiter=depends.limiter,
) )
if security_requirement: if security_requirement:
sub_dependant.security_requirements.append(security_requirement) sub_dependant.security_requirements.append(security_requirement)
@ -176,6 +177,7 @@ def get_flat_dependant(
body_params=dependant.body_params.copy(), body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(), security_requirements=dependant.security_requirements.copy(),
use_cache=dependant.use_cache, use_cache=dependant.use_cache,
limiter=dependant.limiter,
path=dependant.path, path=dependant.path,
) )
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
@ -244,6 +246,7 @@ def get_dependant(
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
use_cache: bool = True, use_cache: bool = True,
limiter: Optional[anyio.CapacityLimiter] = None,
) -> Dependant: ) -> Dependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call) endpoint_signature = get_typed_signature(call)
@ -254,6 +257,7 @@ def get_dependant(
path=path, path=path,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=use_cache, use_cache=use_cache,
limiter=limiter,
) )
for param_name, param in signature_params.items(): for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names is_path_param = param_name in path_param_names
@ -529,10 +533,12 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator( async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any],
limiter: Optional[anyio.CapacityLimiter] = None,
) -> Any: ) -> Any:
if is_gen_callable(call): if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values),
limiter=limiter)
elif is_async_gen_callable(call): elif is_async_gen_callable(call):
cm = asynccontextmanager(call)(**sub_values) cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm) return await stack.enter_async_context(cm)
@ -587,6 +593,7 @@ async def solve_dependencies(
call=call, call=call,
name=sub_dependant.name, name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
limiter=sub_dependant.limiter,
) )
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
@ -608,12 +615,14 @@ async def solve_dependencies(
solved = dependency_cache[sub_dependant.cache_key] solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call): elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator( solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values call=call, stack=async_exit_stack, sub_values=solved_result.values,
limiter=sub_dependant.limiter,
) )
elif is_coroutine_callable(call): elif is_coroutine_callable(call):
solved = await call(**solved_result.values) solved = await call(**solved_result.values)
else: else:
solved = await run_in_threadpool(call, **solved_result.values) solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter,
**solved_result.values)
if sub_dependant.name is not None: if sub_dependant.name is not None:
values[sub_dependant.name] = solved values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache: if sub_dependant.cache_key not in dependency_cache:

32
fastapi/param_functions.py

@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import anyio
from fastapi import params from fastapi import params
from fastapi._compat import Undefined from fastapi._compat import Undefined
@ -2244,6 +2245,20 @@ def Depends( # noqa: N802
""" """
), ),
] = True, ] = True,
limiter: Annotated[
anyio.CapacityLimiter,
Doc(
"""
By default, synchronous dependencies will be run in a threadpool
with the number of concurrent threads limited by the current default anyio
thread limiter. A different `anyio.CapacityLimiter` may be specified
for problematic dependencies to use a different (logical) thread pool with
other limits in order to avoid blocking other threads.
For async dependencies (defined using `async def`) this parameter is ignored.
"""
),
] = None,
) -> Any: ) -> Any:
""" """
Declare a FastAPI dependency. Declare a FastAPI dependency.
@ -2274,7 +2289,7 @@ def Depends( # noqa: N802
return commons return commons
``` ```
""" """
return params.Depends(dependency=dependency, use_cache=use_cache) return params.Depends(dependency=dependency, use_cache=use_cache, limiter=limiter)
def Security( # noqa: N802 def Security( # noqa: N802
@ -2321,6 +2336,18 @@ def Security( # noqa: N802
""" """
), ),
] = True, ] = True,
limiter: Annotated[
anyio.CapacityLimiter,
Doc(
"""
By default, synchronous dependencies will be run in a threadpool
with the number of concurrent threads limited by the current default anyio
thread limiter. A different `anyio.CapacityLimiter` may be specified
for problematic dependencies to use a different (logical) thread pool with
other limits in order to avoid blocking other threads.
"""
),
] = None,
) -> Any: ) -> Any:
""" """
Declare a FastAPI Security dependency. Declare a FastAPI Security dependency.
@ -2357,4 +2384,5 @@ def Security( # noqa: N802
return [{"item_id": "Foo", "owner": current_user.username}] return [{"item_id": "Foo", "owner": current_user.username}]
``` ```
""" """
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache) return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache,
limiter=limiter)

13
fastapi/params.py

@ -1,6 +1,7 @@
import warnings import warnings
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import anyio
from fastapi.openapi.models import Example from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
@ -760,15 +761,20 @@ class File(Form):
class Depends: class Depends:
def __init__( def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True self, dependency: Optional[Callable[..., Any]] = None, *,
use_cache: bool = True,
limiter: anyio.CapacityLimiter | None = None,
): ):
self.dependency = dependency self.dependency = dependency
self.use_cache = use_cache self.use_cache = use_cache
self.limiter = limiter
def __repr__(self) -> str: def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False" cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})" limiter = f", limiter=CapacityLimiter({self.limiter.total_tokens})" \
if self.limiter else ""
return f"{self.__class__.__name__}({attr}{cache}{limiter})"
class Security(Depends): class Security(Depends):
@ -778,6 +784,7 @@ class Security(Depends):
*, *,
scopes: Optional[Sequence[str]] = None, scopes: Optional[Sequence[str]] = None,
use_cache: bool = True, use_cache: bool = True,
limiter: anyio.CapacityLimiter | None = None,
): ):
super().__init__(dependency=dependency, use_cache=use_cache) super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter)
self.scopes = scopes or [] self.scopes = scopes or []

Loading…
Cancel
Save