You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

55 lines
2.1 KiB

import functools
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractContextManager
from contextlib import asynccontextmanager as asynccontextmanager
from typing import ParamSpec, TypeVar
import anyio.to_thread
from anyio import CapacityLimiter
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette.concurrency import ( # noqa
run_until_first_complete as run_until_first_complete,
)
# Parameter specification of the callable handed to run_in_teardown_threadpool,
# so its positional/keyword arguments are type-checked against that callable.
_P = ParamSpec("_P")
# Generic result type: the value produced by a wrapped context manager's
# __enter__, or the return value of a callable run in the teardown threadpool.
_T = TypeVar("_T")
# If __exit__ and other teardown operations are blocked from running, that can
# create race conditions/deadlocks when the context manager itself has its own
# internal pool (e.g. a database connection pool).
# To avoid this, maintain a separate limiter for teardown operations, so that
# operations acquiring resources can never block operations releasing resources.
# NOTE: 5 is arbitrary; we would like more than 1 so that teardowns are not
# serialised.
_teardown_limiter = CapacityLimiter(5)
@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: AbstractContextManager[_T],
) -> AsyncGenerator[_T, None]:
    """Drive a synchronous context manager from async code.

    ``cm.__enter__`` is executed through ``run_in_threadpool`` and
    ``cm.__exit__`` through ``run_in_teardown_threadpool``, so neither call
    runs directly in the calling coroutine.

    Args:
        cm: The synchronous context manager to enter and exit.

    Yields:
        Whatever ``cm.__enter__()`` returns.

    Raises:
        Exception: Any ``Exception`` raised inside the ``async with`` body is
            re-raised unless ``cm.__exit__`` returns a truthy value to
            suppress it. An exception raised by ``cm.__enter__`` propagates
            unchanged and ``cm.__exit__`` is not called.

    NOTE: only ``Exception`` subclasses raised by the body are forwarded to
    ``cm.__exit__``; other ``BaseException`` types pass through this function
    without ``__exit__`` being invoked.
    """
    # Enter *before* the ``try`` block: per the ``with`` statement protocol,
    # ``__exit__`` must only run for a manager whose ``__enter__`` succeeded.
    # Entering inside the ``try`` would call ``__exit__`` on a manager that was
    # never entered (e.g. releasing a lock that was never acquired).
    entered_value = await run_in_threadpool(cm.__enter__)
    try:
        yield entered_value
    except Exception as e:
        # Hand the body's exception to the manager; a truthy result means the
        # manager chose to suppress it.
        ok = bool(
            await run_in_teardown_threadpool(cm.__exit__, type(e), e, e.__traceback__)
        )
        if not ok:
            raise e
    else:
        # Body finished normally: plain exit with no exception information.
        await run_in_teardown_threadpool(cm.__exit__, None, None, None)
async def run_in_teardown_threadpool(
    func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Execute ``func(*args, **kwargs)`` in a worker thread under the teardown limiter.

    The call is dispatched with ``anyio.to_thread.run_sync`` using the
    module-level ``_teardown_limiter`` rather than the default limiter, so a
    release/cleanup operation is not stuck behind operations that are still
    waiting to acquire resources.

    This is a specialised helper; for ordinary blocking work prefer
    ``run_in_threadpool``.

    Args:
        func: The synchronous callable to run.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value returned by ``func``.
    """
    # Bind every argument up front so keyword arguments reach ``func`` too;
    # only the zero-argument bound callable is handed to ``run_sync``.
    bound_call = functools.partial(func, *args, **kwargs)
    result = await anyio.to_thread.run_sync(bound_call, limiter=_teardown_limiter)
    return result