From e14a92ab93740bf6a92109d49328872a383b8b0a Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Thu, 25 Jul 2024 15:12:53 +0200 Subject: [PATCH 01/10] temporarily add a local copy of starlettes run_in_threadpool and add a _limiter keyword argument (maybe this should be done upstream?) --- fastapi/concurrency.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 894bd3ed1..aa429b617 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,5 +1,12 @@ from contextlib import asynccontextmanager as asynccontextmanager from typing import AsyncGenerator, ContextManager, TypeVar +import functools +import sys +import typing +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec import anyio from anyio import CapacityLimiter @@ -9,8 +16,18 @@ from starlette.concurrency import ( # noqa run_until_first_complete as run_until_first_complete, ) +_P = ParamSpec("_P") _T = TypeVar("_T") +async def run_in_threadpool( + func: typing.Callable[_P, _T], *args: _P.args, + _limiter: anyio.CapacityLimiter | None = None, + **kwargs: _P.kwargs +) -> _T: + if kwargs: # pragma: no cover + # run_sync doesn't accept 'kwargs', so bind them in here + func = functools.partial(func, **kwargs) + return await anyio.to_thread.run_sync(func, *args, limiter=_limiter) @asynccontextmanager async def contextmanager_in_threadpool( From 100b0f9507cc3638e85efc3cef101a328c1a7577 Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Fri, 26 Jul 2024 10:42:33 +0200 Subject: [PATCH 02/10] use Optional type hint to be compatible with python < 3.10 add limiter option to contextmanager_in_threadpool() --- fastapi/concurrency.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index aa429b617..25aa705f4 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager as asynccontextmanager -from typing import AsyncGenerator, ContextManager, TypeVar +from typing import AsyncGenerator, ContextManager, TypeVar, Optional import functools import sys import typing @@ -21,7 +21,7 @@ _T = TypeVar("_T") async def run_in_threadpool( func: typing.Callable[_P, _T], *args: _P.args, - _limiter: anyio.CapacityLimiter | None = None, + _limiter: Optional[anyio.CapacityLimiter] = None, **kwargs: _P.kwargs ) -> _T: if kwargs: # pragma: no cover @@ -31,7 +31,7 @@ async def run_in_threadpool( @asynccontextmanager async def contextmanager_in_threadpool( - cm: ContextManager[_T], + cm: ContextManager[_T], limiter: Optional[anyio.CapacityLimiter] = None, ) -> AsyncGenerator[_T, None]: # blocking __exit__ from running waiting on a free thread # can create race conditions/deadlocks if the context manager itself @@ -41,16 +41,16 @@ async def contextmanager_in_threadpool( # works (1 is arbitrary) exit_limiter = CapacityLimiter(1) try: - yield await run_in_threadpool(cm.__enter__) + yield await run_in_threadpool(cm.__enter__, _limiter=limiter) except Exception as e: ok = bool( - await anyio.to_thread.run_sync( - cm.__exit__, type(e), e, None, limiter=exit_limiter + await run_in_threadpool( + cm.__exit__, type(e), e, None, _limiter=exit_limiter ) ) if not ok: raise e else: - await anyio.to_thread.run_sync( - cm.__exit__, None, None, None, limiter=exit_limiter + await run_in_threadpool( + cm.__exit__, None, None, None, _limiter=exit_limiter ) From 50e89daa5237e4ac5ee07374fd87e98621f5556a Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Fri, 26 Jul 2024 11:23:32 +0200 Subject: [PATCH 03/10] add limiter keyword argument to Dependency and Security --- fastapi/dependencies/models.py | 2 ++ fastapi/dependencies/utils.py | 19 ++++++++++++++----- fastapi/param_functions.py | 32 ++++++++++++++++++++++++++++++-- fastapi/params.py | 13 ++++++++++--- 4 files changed, 56 insertions(+), 10 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..0891213e8 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Sequence, Tuple +import anyio from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -32,6 +33,7 @@ class Dependant: use_cache: bool = True path: Optional[str] = None cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) + limiter: Optional[anyio.CapacityLimiter] = None def __post_init__(self) -> None: self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0dcba62f1..96a36de1d 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -49,6 +49,7 @@ from fastapi.background import BackgroundTasks from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, + run_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.logger import logger @@ -58,7 +59,6 @@ from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.utils import create_model_field, get_path_param_names from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks -from starlette.concurrency import run_in_threadpool from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.requests import HTTPConnection, Request from starlette.responses import Response @@ -149,6 +149,7 @@ def get_sub_dependant( name=name, security_scopes=security_scopes, use_cache=depends.use_cache, + limiter=depends.limiter, ) if security_requirement: sub_dependant.security_requirements.append(security_requirement) @@ -176,6 +177,7 @@ def get_flat_dependant( body_params=dependant.body_params.copy(), security_requirements=dependant.security_requirements.copy(), use_cache=dependant.use_cache, + limiter=dependant.limiter, path=dependant.path, ) for sub_dependant in dependant.dependencies: @@ -244,6 +246,7 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + limiter: Optional[anyio.CapacityLimiter] = None, ) -> Dependant: path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -254,6 +257,7 @@ def get_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + limiter=limiter, ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names @@ -529,10 +533,12 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] + *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any], + limiter: Optional[anyio.CapacityLimiter] = None, ) -> Any: if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) + cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values), + limiter=limiter) elif is_async_gen_callable(call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -587,6 +593,7 @@ async def solve_dependencies( call=call, name=sub_dependant.name, security_scopes=sub_dependant.security_scopes, + limiter=sub_dependant.limiter, ) solved_result = await solve_dependencies( @@ -608,12 +615,14 @@ async def solve_dependencies( solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values + call=call, stack=async_exit_stack, sub_values=solved_result.values, + limiter=sub_dependant.limiter, ) elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, **solved_result.values) + solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter, + **solved_result.values) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 0d5f27af4..77accd9f6 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import anyio from fastapi import params from fastapi._compat import Undefined @@ -2244,6 +2245,20 @@ def Depends( # noqa: N802 """ ), ] = True, + limiter: Annotated[ + anyio.CapacityLimiter, + Doc( + """ + By default, synchronous dependencies will be run in a threadpool + with the number of concurrent threads limited by the current default anyio + thread limiter. A different `anyio.CapacityLimiter` may be specified + for problematic dependencies to use a different (logical) thread pool with + other limits in order to avoid blocking other threads. + + For async dependencies (defined using `async def`) this parameter is ignored. + """ + ), + ] = None, ) -> Any: """ Declare a FastAPI dependency. @@ -2274,7 +2289,7 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache) + return params.Depends(dependency=dependency, use_cache=use_cache, limiter=limiter) def Security( # noqa: N802 @@ -2321,6 +2336,18 @@ def Security( # noqa: N802 """ ), ] = True, + limiter: Annotated[ + anyio.CapacityLimiter, + Doc( + """ + By default, synchronous dependencies will be run in a threadpool + with the number of concurrent threads limited by the current default anyio + thread limiter. A different `anyio.CapacityLimiter` may be specified + for problematic dependencies to use a different (logical) thread pool with + other limits in order to avoid blocking other threads. + """ + ), + ] = None, ) -> Any: """ Declare a FastAPI Security dependency. @@ -2357,4 +2384,5 @@ def Security( # noqa: N802 return [{"item_id": "Foo", "owner": current_user.username}] ``` """ - return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache) + return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache, + limiter=limiter) diff --git a/fastapi/params.py b/fastapi/params.py index cc2a5c13c..499c2d66f 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -1,6 +1,7 @@ import warnings from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import anyio from fastapi.openapi.models import Example from pydantic.fields import FieldInfo @@ -760,15 +761,20 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + self, dependency: Optional[Callable[..., Any]] = None, *, + use_cache: bool = True, + limiter: anyio.CapacityLimiter | None = None, ): self.dependency = dependency self.use_cache = use_cache + self.limiter = limiter def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({attr}{cache})" + limiter = f", limiter=CapacityLimiter({self.limiter.total_tokens})" \ + if self.limiter else "" + return f"{self.__class__.__name__}({attr}{cache}{limiter})" class Security(Depends): @@ -778,6 +784,7 @@ class Security(Depends): *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, + limiter: anyio.CapacityLimiter | None = None, ): - super().__init__(dependency=dependency, use_cache=use_cache) + super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter) self.scopes = scopes or [] From 3d47a2c5e7beecfedc22e64a8cc14e125913bf40 Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Fri, 26 Jul 2024 14:05:15 +0200 Subject: [PATCH 04/10] try to satisfy linter --- fastapi/concurrency.py | 1 - fastapi/param_functions.py | 4 ++-- fastapi/params.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 25aa705f4..0a7aebc99 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -11,7 +11,6 @@ else: # pragma: no cover import anyio from anyio import CapacityLimiter from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa -from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa from starlette.concurrency import ( # noqa run_until_first_complete as run_until_first_complete, ) diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 77accd9f6..86c3c22e8 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -2246,7 +2246,7 @@ def Depends( # noqa: N802 ), ] = True, limiter: Annotated[ - anyio.CapacityLimiter, + Optional[anyio.CapacityLimiter], Doc( """ By default, synchronous dependencies will be run in a threadpool @@ -2337,7 +2337,7 @@ def Security( # noqa: N802 ), ] = True, limiter: Annotated[ - anyio.CapacityLimiter, + Optional[anyio.CapacityLimiter], Doc( """ By default, synchronous dependencies will be run in a threadpool diff --git a/fastapi/params.py b/fastapi/params.py index 499c2d66f..41494044f 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -763,7 +763,7 @@ class Depends: def __init__( self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True, - limiter: anyio.CapacityLimiter | None = None, + limiter: Optional[anyio.CapacityLimiter] = None, ): self.dependency = dependency self.use_cache = use_cache @@ -784,7 +784,7 @@ class Security(Depends): *, scopes: Optional[Sequence[str]] = None, use_cache: bool = True, - limiter: anyio.CapacityLimiter | None = None, + limiter: Optional[anyio.CapacityLimiter] = None, ): super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter) self.scopes = scopes or [] From d4c0831eb762ab075ee2702a8150d57770edeaf5 Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Tue, 13 Aug 2024 10:55:20 +0200 Subject: [PATCH 05/10] add test --- tests/test_depends_limiter.py | 125 ++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 tests/test_depends_limiter.py diff --git a/tests/test_depends_limiter.py b/tests/test_depends_limiter.py new file mode 100644 index 000000000..f7e015b7e --- /dev/null +++ b/tests/test_depends_limiter.py @@ -0,0 +1,125 @@ +from contextlib import asynccontextmanager +import threading + +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from typing_extensions import Annotated +import anyio +import sniffio + + +def get_borrowed_tokens(): + return { + lname: l.borrowed_tokens + for lname, l in sorted(limiters.items()) + } + +def limited_dep(): + # run this in the event loop thread: + tokens = anyio.from_thread.run_sync(get_borrowed_tokens) + yield tokens + + +def init_limiters(): + return { + 'default': None, # should be set in Lifespan handler + 'a': anyio.CapacityLimiter(5), + 'b': anyio.CapacityLimiter(3), + } + +# Note: +# initializing CapacityLimiters at module level before the event loop has started +# needs anyio >= 4.2.0; starlette currently requires anyio >= 3.4.0 +# see https://github.com/agronholm/anyio/pull/651 + +# The following is a temporary workaround for anyio < 4.2.0: +try: + limiters = init_limiters() +except sniffio.AsyncLibraryNotFoundError: + # not in an async context yet + async def _init_limiters(): + return init_limiters() + limiters = anyio.run(_init_limiters) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + limiters['default'] = anyio.to_thread.current_default_thread_limiter() + yield { + 'limiters': limiters, + } + +app = FastAPI(lifespan=lifespan) + + +@app.get("/") +async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]): + return borrowed_tokens + +@app.get("/a") +async def a(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['a'])]): + return borrowed_tokens + +@app.get("/b") +async def b(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['b'])]): + return borrowed_tokens + + + +def test_depends_limiter(): + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"a":0,"b":0,"default":1} + + response = client.get("/a") + assert response.status_code == 200, response.text + assert response.json() == {"a":1,"b":0,"default":0} + + response = client.get("/b") + assert response.status_code == 200, response.text + assert response.json() == {"a":0,"b":1,"default":0} + + +def test_openapi_schema(): + with TestClient(app) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json()['paths'] == { + '/': { + 'get': { + 'summary': 'Root', + 'operationId': 'root__get', + 'responses': { + '200': { + 'description': 'Successful Response', + 'content': {'application/json': {'schema': {}}} + } + } + } + }, + '/a': { + 'get': { + 'summary': 'A', + 'operationId': 'a_a_get', + 'responses': { + '200': { + 'description': 'Successful Response', + 'content': {'application/json': {'schema': {}}} + } + } + } + }, + '/b': { + 'get': { + 'summary': 'B', + 'operationId': 'b_b_get', + 'responses': { + '200': { + 'description': 'Successful Response', + 'content': {'application/json': {'schema': {}}} + } + } + } + } + } From 4fe87d40d17fc5d94d02a78051d3a400c01b3ba2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Aug 2024 08:55:45 +0000 Subject: [PATCH 06/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_depends_limiter.py | 99 ++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 47 deletions(-) diff --git a/tests/test_depends_limiter.py b/tests/test_depends_limiter.py index f7e015b7e..9158a3bb1 100644 --- a/tests/test_depends_limiter.py +++ b/tests/test_depends_limiter.py @@ -1,18 +1,15 @@ from contextlib import asynccontextmanager -import threading +import anyio +import sniffio from fastapi import Depends, FastAPI from fastapi.testclient import TestClient from typing_extensions import Annotated -import anyio -import sniffio def get_borrowed_tokens(): - return { - lname: l.borrowed_tokens - for lname, l in sorted(limiters.items()) - } + return {lname: l.borrowed_tokens for lname, l in sorted(limiters.items())} + def limited_dep(): # run this in the event loop thread: @@ -22,11 +19,12 @@ def limited_dep(): def init_limiters(): return { - 'default': None, # should be set in Lifespan handler - 'a': anyio.CapacityLimiter(5), - 'b': anyio.CapacityLimiter(3), + "default": None, # should be set in Lifespan handler + "a": anyio.CapacityLimiter(5), + "b": anyio.CapacityLimiter(3), } + # Note: # initializing CapacityLimiters at module level before the event loop has started # needs anyio >= 4.2.0; starlette currently requires anyio >= 3.4.0 @@ -39,16 +37,18 @@ except sniffio.AsyncLibraryNotFoundError: # not in an async context yet async def _init_limiters(): return init_limiters() + limiters = anyio.run(_init_limiters) @asynccontextmanager async def lifespan(app: FastAPI): - limiters['default'] = anyio.to_thread.current_default_thread_limiter() + limiters["default"] = anyio.to_thread.current_default_thread_limiter() yield { - 'limiters': limiters, + "limiters": limiters, } + app = FastAPI(lifespan=lifespan) @@ -56,70 +56,75 @@ app = FastAPI(lifespan=lifespan) async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]): return borrowed_tokens + @app.get("/a") -async def a(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['a'])]): +async def a( + borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["a"])], +): return borrowed_tokens + @app.get("/b") -async def b(borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters['b'])]): +async def b( + borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["b"])], +): return borrowed_tokens - def test_depends_limiter(): with TestClient(app) as client: response = client.get("/") assert response.status_code == 200, response.text - assert response.json() == {"a":0,"b":0,"default":1} + assert response.json() == {"a": 0, "b": 0, "default": 1} response = client.get("/a") assert response.status_code == 200, response.text - assert response.json() == {"a":1,"b":0,"default":0} + assert response.json() == {"a": 1, "b": 0, "default": 0} response = client.get("/b") assert response.status_code == 200, response.text - assert response.json() == {"a":0,"b":1,"default":0} + assert response.json() == {"a": 0, "b": 1, "default": 0} def test_openapi_schema(): with TestClient(app) as client: response = client.get("/openapi.json") assert response.status_code == 200, response.text - assert response.json()['paths'] == { - '/': { - 'get': { - 'summary': 'Root', - 'operationId': 'root__get', - 'responses': { - '200': { - 'description': 'Successful Response', - 'content': {'application/json': {'schema': {}}} + assert response.json()["paths"] == { + "/": { + "get": { + "summary": "Root", + "operationId": "root__get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, } - } + }, } }, - '/a': { - 'get': { - 'summary': 'A', - 'operationId': 'a_a_get', - 'responses': { - '200': { - 'description': 'Successful Response', - 'content': {'application/json': {'schema': {}}} + "/a": { + "get": { + "summary": "A", + "operationId": "a_a_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, } - } + }, } }, - '/b': { - 'get': { - 'summary': 'B', - 'operationId': 'b_b_get', - 'responses': { - '200': { - 'description': 'Successful Response', - 'content': {'application/json': {'schema': {}}} + "/b": { + "get": { + "summary": "B", + "operationId": "b_b_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, } - } + }, } - } + }, } From cdd63bb55514622643d7bd2e1c0ce55aece46bb5 Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Tue, 13 Aug 2024 11:06:25 +0200 Subject: [PATCH 07/10] make flake8 checker happy --- tests/test_depends_limiter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_depends_limiter.py b/tests/test_depends_limiter.py index 9158a3bb1..1da219939 100644 --- a/tests/test_depends_limiter.py +++ b/tests/test_depends_limiter.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated def get_borrowed_tokens(): - return {lname: l.borrowed_tokens for lname, l in sorted(limiters.items())} + return {lname: lobj.borrowed_tokens for lname, lobj in sorted(limiters.items())} def limited_dep(): From bf38db2c4ba4cffd3e36fbc6deb6581265837aeb Mon Sep 17 00:00:00 2001 From: Alexander 'Leo' Bergolth Date: Wed, 4 Sep 2024 16:39:31 +0200 Subject: [PATCH 08/10] type hint changes --- fastapi/concurrency.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 0a7aebc99..83bb9e157 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -19,9 +19,10 @@ _P = ParamSpec("_P") _T = TypeVar("_T") async def run_in_threadpool( - func: typing.Callable[_P, _T], *args: _P.args, - _limiter: Optional[anyio.CapacityLimiter] = None, - **kwargs: _P.kwargs + func: typing.Callable[_P, _T], + *args: typing.Any, + _limiter: Optional[anyio.CapacityLimiter] = None, + **kwargs: typing.Any ) -> _T: if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here From 32a58515b96a0b8e118e8c507286b08a1055142f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:58:50 +0000 Subject: [PATCH 09/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/concurrency.py | 16 +++++++++------- fastapi/dependencies/models.py | 2 +- fastapi/dependencies/utils.py | 19 +++++++++++++------ fastapi/param_functions.py | 7 ++++--- fastapi/params.py | 13 +++++++++---- 5 files changed, 36 insertions(+), 21 deletions(-) diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 83bb9e157..132ce4199 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,8 +1,9 @@ -from contextlib import asynccontextmanager as asynccontextmanager -from typing import AsyncGenerator, ContextManager, TypeVar, Optional import functools import sys import typing +from contextlib import asynccontextmanager as asynccontextmanager +from typing import AsyncGenerator, ContextManager, Optional, TypeVar + if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec else: # pragma: no cover @@ -18,20 +19,23 @@ from starlette.concurrency import ( # noqa _P = ParamSpec("_P") _T = TypeVar("_T") + async def run_in_threadpool( func: typing.Callable[_P, _T], *args: typing.Any, _limiter: Optional[anyio.CapacityLimiter] = None, - **kwargs: typing.Any + **kwargs: typing.Any, ) -> _T: if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) return await anyio.to_thread.run_sync(func, *args, limiter=_limiter) + @asynccontextmanager async def contextmanager_in_threadpool( - cm: ContextManager[_T], limiter: Optional[anyio.CapacityLimiter] = None, + cm: ContextManager[_T], + limiter: Optional[anyio.CapacityLimiter] = None, ) -> AsyncGenerator[_T, None]: # blocking __exit__ from running waiting on a free thread # can create race conditions/deadlocks if the context manager itself @@ -51,6 +55,4 @@ async def contextmanager_in_threadpool( if not ok: raise e else: - await run_in_threadpool( - cm.__exit__, None, None, None, _limiter=exit_limiter - ) + await run_in_threadpool(cm.__exit__, None, None, None, _limiter=exit_limiter) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 0891213e8..a152cf212 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Sequence, Tuple -import anyio +import anyio from fastapi._compat import ModelField from fastapi.security.base import SecurityBase diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 96a36de1d..72db9f421 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -533,12 +533,16 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any], + *, + call: Callable[..., Any], + stack: AsyncExitStack, + sub_values: Dict[str, Any], limiter: Optional[anyio.CapacityLimiter] = None, ) -> Any: if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values), - limiter=limiter) + cm = contextmanager_in_threadpool( + contextmanager(call)(**sub_values), limiter=limiter + ) elif is_async_gen_callable(call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -615,14 +619,17 @@ async def solve_dependencies( solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values, + call=call, + stack=async_exit_stack, + sub_values=solved_result.values, limiter=sub_dependant.limiter, ) elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter, - **solved_result.values) + solved = await run_in_threadpool( + call, _limiter=sub_dependant.limiter, **solved_result.values + ) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 86c3c22e8..7722f3b4c 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union -import anyio +import anyio from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example @@ -2384,5 +2384,6 @@ def Security( # noqa: N802 return [{"item_id": "Foo", "owner": current_user.username}] ``` """ - return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache, - limiter=limiter) + return params.Security( + dependency=dependency, scopes=scopes, use_cache=use_cache, limiter=limiter + ) diff --git a/fastapi/params.py b/fastapi/params.py index 41494044f..a936e9892 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -1,8 +1,8 @@ import warnings from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Union -import anyio +import anyio from fastapi.openapi.models import Example from pydantic.fields import FieldInfo from typing_extensions import Annotated, deprecated @@ -761,7 +761,9 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, + self, + dependency: Optional[Callable[..., Any]] = None, + *, use_cache: bool = True, limiter: Optional[anyio.CapacityLimiter] = None, ): @@ -772,8 +774,11 @@ class Depends: def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - limiter = f", limiter=CapacityLimiter({self.limiter.total_tokens})" \ - if self.limiter else "" + limiter = ( + f", limiter=CapacityLimiter({self.limiter.total_tokens})" + if self.limiter + else "" + ) return f"{self.__class__.__name__}({attr}{cache}{limiter})" From dd0209864fed72b0575936e8952b605f956f4d6c Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: Tue, 22 Jul 2025 09:30:42 +0200 Subject: [PATCH 10/10] Fix import of `run_in_threadpool` --- fastapi/dependencies/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index a3da5a2f5..a559cc3f5 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -61,7 +61,6 @@ from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks -from starlette.concurrency import run_in_threadpool from starlette.datastructures import ( FormData, Headers,