@ -1,6 +1,8 @@
import sys
import sys
from typing import AsyncGenerator , ContextManager , TypeVar
from typing import AsyncGenerator , ContextManager , TypeVar
import anyio
from anyio import CapacityLimiter
from starlette . concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette . concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette . concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette . concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette . concurrency import ( # noqa
from starlette . concurrency import ( # noqa
@ -22,11 +24,24 @@ _T = TypeVar("_T")
async def contextmanager_in_threadpool (
async def contextmanager_in_threadpool (
cm : ContextManager [ _T ] ,
cm : ContextManager [ _T ] ,
) - > AsyncGenerator [ _T , None ] :
) - > AsyncGenerator [ _T , None ] :
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself
# has it's own internal pool (e.g. a database connection pool)
# to avoid this we let __exit__ run without a capacity limit
# since we're creating a new limiter for each call, any non-zero limit
# works (1 is arbitrary)
exit_limiter = CapacityLimiter ( 1 )
try :
try :
yield await run_in_threadpool ( cm . __enter__ )
yield await run_in_threadpool ( cm . __enter__ )
except Exception as e :
except Exception as e :
ok : bool = await run_in_threadpool ( cm . __exit__ , type ( e ) , e , None )
ok = bool (
await anyio . to_thread . run_sync (
cm . __exit__ , type ( e ) , e , None , limiter = exit_limiter
)
)
if not ok :
if not ok :
raise e
raise e
else :
else :
await run_in_threadpool ( cm . __exit__ , None , None , None )
await anyio . to_thread . run_sync (
cm . __exit__ , None , None , None , limiter = exit_limiter
)