|
|
@ -1,5 +1,12 @@ |
|
|
|
from contextlib import asynccontextmanager as asynccontextmanager |
|
|
|
from typing import AsyncGenerator, ContextManager, TypeVar |
|
|
|
import functools |
|
|
|
import sys |
|
|
|
import typing |
|
|
|
if sys.version_info >= (3, 10): # pragma: no cover |
|
|
|
from typing import ParamSpec |
|
|
|
else: # pragma: no cover |
|
|
|
from typing_extensions import ParamSpec |
|
|
|
|
|
|
|
import anyio |
|
|
|
from anyio import CapacityLimiter |
|
|
@ -9,8 +16,18 @@ from starlette.concurrency import ( # noqa |
|
|
|
run_until_first_complete as run_until_first_complete, |
|
|
|
) |
|
|
|
|
|
|
|
_P = ParamSpec("_P") |
|
|
|
_T = TypeVar("_T") |
|
|
|
|
|
|
|
async def run_in_threadpool( |
|
|
|
func: typing.Callable[_P, _T], *args: _P.args, |
|
|
|
_limiter: anyio.CapacityLimiter | None = None, |
|
|
|
**kwargs: _P.kwargs |
|
|
|
) -> _T: |
|
|
|
if kwargs: # pragma: no cover |
|
|
|
# run_sync doesn't accept 'kwargs', so bind them in here |
|
|
|
func = functools.partial(func, **kwargs) |
|
|
|
return await anyio.to_thread.run_sync(func, *args, limiter=_limiter) |
|
|
|
|
|
|
|
@asynccontextmanager |
|
|
|
async def contextmanager_in_threadpool( |
|
|
|