diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 3202c7078..26263beae 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,20 +1,41 @@ +import functools +import sys +import typing from contextlib import asynccontextmanager as asynccontextmanager -from typing import AsyncGenerator, ContextManager, TypeVar +from typing import AsyncGenerator, ContextManager, Optional, TypeVar + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec import anyio.to_thread from anyio import CapacityLimiter from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa -from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa from starlette.concurrency import ( # noqa run_until_first_complete as run_until_first_complete, ) +_P = ParamSpec("_P") _T = TypeVar("_T") +async def run_in_threadpool( + func: typing.Callable[_P, _T], + *args: typing.Any, + _limiter: Optional[anyio.CapacityLimiter] = None, + **kwargs: typing.Any, +) -> _T: + if kwargs: # pragma: no cover + # run_sync doesn't accept 'kwargs', so bind them in here + func = functools.partial(func, **kwargs) + return await anyio.to_thread.run_sync(func, *args, limiter=_limiter) + + @asynccontextmanager async def contextmanager_in_threadpool( cm: ContextManager[_T], + limiter: Optional[anyio.CapacityLimiter] = None, ) -> AsyncGenerator[_T, None]: # blocking __exit__ from running waiting on a free thread # can create race conditions/deadlocks if the context manager itself @@ -24,16 +45,14 @@ async def contextmanager_in_threadpool( # works (1 is arbitrary) exit_limiter = CapacityLimiter(1) try: - yield await run_in_threadpool(cm.__enter__) + yield await run_in_threadpool(cm.__enter__, _limiter=limiter) except Exception as e: ok = bool( - await anyio.to_thread.run_sync( - cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter + await run_in_threadpool( + cm.__exit__, type(e), e, e.__traceback__, _limiter=exit_limiter ) ) if not ok: raise e else: - await anyio.to_thread.run_sync( - cm.__exit__, None, None, None, limiter=exit_limiter - ) + await run_in_threadpool(cm.__exit__, None, None, None, _limiter=exit_limiter) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..a152cf212 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Sequence, Tuple +import anyio from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -32,6 +33,7 @@ class Dependant: use_cache: bool = True path: Optional[str] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) + limiter: Optional[anyio.CapacityLimiter] = None def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 081b63a8b..a559cc3f5 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -50,6 +50,7 @@ from fastapi.background import BackgroundTasks from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, + run_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.logger import logger @@ -60,7 +61,6 @@ from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks -from starlette.concurrency import run_in_threadpool from starlette.datastructures import ( FormData, Headers, @@ -165,6 +165,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + limiter=depends.limiter, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -192,6 +193,7 @@ def get_flat_dependant( body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), use_cache=dependant.use_cache, + limiter=dependant.limiter, path=dependant.path, ) for sub_dependant in dependant.dependencies: @@ -269,6 +271,7 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + limiter: Optional[anyio.CapacityLimiter] = None, ) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -279,6 +282,7 @@ def get_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + limiter=limiter, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -551,10 +555,16 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] + *, + call: Callable[..., Any], + stack: AsyncExitStack, + sub_values: Dict[str, Any], + limiter: Optional[anyio.CapacityLimiter] = None, ) -> Any: if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) + cm = contextmanager_in_threadpool( + contextmanager(call)(**sub_values), limiter=limiter + ) elif is_async_gen_callable(call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -610,6 +620,7 @@ async def solve_dependencies( call=call, name=sub_dependant.name, security_scopes=sub_dependant.security_scopes, + limiter=sub_dependant.limiter, ) solved_result = await solve_dependencies( @@ -632,12 +643,17 @@ async def solve_dependencies( solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values + call=call, + stack=async_exit_stack, + sub_values=solved_result.values, + limiter=sub_dependant.limiter, ) elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, **solved_result.values) + solved = await run_in_threadpool( + call, _limiter=sub_dependant.limiter, **solved_result.values + ) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index b3621626c..948a2a0af 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,5 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import anyio from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example @@ -2244,6 +2245,20 @@ def Depends( # noqa: N802 """ ), ] = True, + limiter: Annotated[ + Optional[anyio.CapacityLimiter], + Doc( + """ + By default, synchronous dependencies will be run in a threadpool + with the number of concurrent threads limited by the current default anyio + thread limiter. A different `anyio.CapacityLimiter` may be specified + for problematic dependencies to use a different (logical) thread pool with + other limits in order to avoid blocking other threads. + + For async dependencies (defined using `async def`) this parameter is ignored. + """ + ), + ] = None, ) -> Any: """ Declare a FastAPI dependency. @@ -2274,7 +2289,7 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache) + return params.Depends(dependency=dependency, use_cache=use_cache, limiter=limiter) def Security( # noqa: N802 @@ -2321,6 +2336,18 @@ def Security( # noqa: N802 """ ), ] = True, + limiter: Annotated[ + Optional[anyio.CapacityLimiter], + Doc( + """ + By default, synchronous dependencies will be run in a threadpool + with the number of concurrent threads limited by the current default anyio + thread limiter. A different `anyio.CapacityLimiter` may be specified + for problematic dependencies to use a different (logical) thread pool with + other limits in order to avoid blocking other threads. + """ + ), + ] = None, ) -> Any: """ Declare a FastAPI Security dependency. @@ -2357,4 +2384,6 @@ def Security( # noqa: N802 return [{"item_id": "Foo", "owner": current_user.username}] ``` """ - return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache) + return params.Security( + dependency=dependency, scopes=scopes, use_cache=use_cache, limiter=limiter + ) diff --git a/fastapi/params.py b/fastapi/params.py index 8f5601dd3..9b1f492a3 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -2,6 +2,7 @@ import warnings from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import anyio from fastapi.openapi.models import Example from pydantic.fields import FieldInfo from typing_extensions import Annotated, deprecated @@ -763,15 +764,25 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, + limiter: Optional[anyio.CapacityLimiter] = None, ): self.dependency = dependency self.use_cache = use_cache + self.limiter = limiter def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({attr}{cache})" + limiter = ( + f", limiter=CapacityLimiter({self.limiter.total_tokens})" + if self.limiter + else "" + ) + return f"{self.__class__.__name__}({attr}{cache}{limiter})" class Security(Depends): @@ -781,6 +792,7 @@ class Security(Depends): *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, + limiter: Optional[anyio.CapacityLimiter] = None, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter) self.scopes = scopes or [] diff --git a/tests/test_depends_limiter.py b/tests/test_depends_limiter.py new file mode 100644 index 000000000..1da219939 --- /dev/null +++ b/tests/test_depends_limiter.py @@ -0,0 +1,130 @@ +from contextlib import asynccontextmanager + +import anyio +import sniffio +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from typing_extensions import Annotated + + +def get_borrowed_tokens(): + return {lname: lobj.borrowed_tokens for lname, lobj in sorted(limiters.items())} + + +def limited_dep(): + # run this in the event loop thread: + tokens = anyio.from_thread.run_sync(get_borrowed_tokens) + yield tokens + + +def init_limiters(): + return { + "default": None, # should be set in Lifespan handler + "a": anyio.CapacityLimiter(5), + "b": anyio.CapacityLimiter(3), + } + + +# Note: +# initializing CapacityLimiters at module level before the event loop has started +# needs anyio >= 4.2.0; starlette currently requires anyio >= 3.4.0 +# see https://github.com/agronholm/anyio/pull/651 + +# The following is a temporary workaround for anyio < 4.2.0: +try: + limiters = init_limiters() +except sniffio.AsyncLibraryNotFoundError: + # not in an async context yet + async def _init_limiters(): + return init_limiters() + + limiters = anyio.run(_init_limiters) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + limiters["default"] = anyio.to_thread.current_default_thread_limiter() + yield { + "limiters": limiters, + } + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/") +async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]): + return borrowed_tokens + + +@app.get("/a") +async def a( + borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["a"])], +): + return borrowed_tokens + + +@app.get("/b") +async def b( + borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["b"])], +): + return borrowed_tokens + + +def test_depends_limiter(): + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"a": 0, "b": 0, "default": 1} + + response = client.get("/a") + assert response.status_code == 200, response.text + assert response.json() == {"a": 1, "b": 0, "default": 0} + + response = client.get("/b") + assert response.status_code == 200, response.text + assert response.json() == {"a": 0, "b": 1, "default": 0} + + +def test_openapi_schema(): + with TestClient(app) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()["paths"] == { + "/": { + "get": { + "summary": "Root", + "operationId": "root__get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + } + }, + "/a": { + "get": { + "summary": "A", + "operationId": "a_a_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + } + }, + "/b": { + "get": { + "summary": "B", + "operationId": "b_b_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + } + }, + }