diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 894bd3ed1..aa429b617 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,5 +1,12 @@ from contextlib import asynccontextmanager as asynccontextmanager from typing import AsyncGenerator, ContextManager, TypeVar +import functools +import sys +import typing +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec import anyio from anyio import CapacityLimiter @@ -9,8 +16,18 @@ from starlette.concurrency import ( # noqa run_until_first_complete as run_until_first_complete, ) +_P = ParamSpec("_P") _T = TypeVar("_T") +async def run_in_threadpool( + func: typing.Callable[_P, _T], *args: _P.args, + _limiter: anyio.CapacityLimiter | None = None, + **kwargs: _P.kwargs +) -> _T: + if kwargs: # pragma: no cover + # run_sync doesn't accept 'kwargs', so bind them in here + func = functools.partial(func, **kwargs) + return await anyio.to_thread.run_sync(func, *args, limiter=_limiter) @asynccontextmanager async def contextmanager_in_threadpool(