diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index aa429b617..25aa705f4 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager as asynccontextmanager -from typing import AsyncGenerator, ContextManager, TypeVar +from typing import AsyncGenerator, ContextManager, TypeVar, Optional import functools import sys import typing @@ -21,7 +21,7 @@ _T = TypeVar("_T") async def run_in_threadpool( func: typing.Callable[_P, _T], *args: _P.args, - _limiter: anyio.CapacityLimiter | None = None, + _limiter: Optional[anyio.CapacityLimiter] = None, **kwargs: _P.kwargs ) -> _T: if kwargs: # pragma: no cover @@ -31,7 +31,7 @@ async def run_in_threadpool( @asynccontextmanager async def contextmanager_in_threadpool( - cm: ContextManager[_T], + cm: ContextManager[_T], limiter: Optional[anyio.CapacityLimiter] = None, ) -> AsyncGenerator[_T, None]: # blocking __exit__ from running waiting on a free thread # can create race conditions/deadlocks if the context manager itself @@ -41,16 +41,16 @@ async def contextmanager_in_threadpool( # works (1 is arbitrary) exit_limiter = CapacityLimiter(1) try: - yield await run_in_threadpool(cm.__enter__) + yield await run_in_threadpool(cm.__enter__, _limiter=limiter) except Exception as e: ok = bool( - await anyio.to_thread.run_sync( - cm.__exit__, type(e), e, None, limiter=exit_limiter + await run_in_threadpool( + cm.__exit__, type(e), e, None, _limiter=exit_limiter ) ) if not ok: raise e else: - await anyio.to_thread.run_sync( - cm.__exit__, None, None, None, limiter=exit_limiter + await run_in_threadpool( + cm.__exit__, None, None, None, _limiter=exit_limiter )