import sys
from typing import AsyncGenerator, ContextManager, TypeVar

from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool  # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool  # noqa
from starlette.concurrency import (  # noqa
    run_until_first_complete as run_until_first_complete,
)

if sys.version_info >= (3, 7):
    from contextlib import AsyncExitStack as AsyncExitStack
    from contextlib import asynccontextmanager as asynccontextmanager
else:
    from contextlib2 import AsyncExitStack as AsyncExitStack  # noqa
    from contextlib2 import asynccontextmanager as asynccontextmanager  # noqa


# Generic type variable linking the type produced by a wrapped context
# manager's ``__enter__`` to the value yielded by contextmanager_in_threadpool.
_T = TypeVar("_T")


@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: ContextManager[_T],
) -> AsyncGenerator[_T, None]:
    """Use a synchronous context manager from async code without blocking.

    ``cm.__enter__`` and ``cm.__exit__`` are each executed via
    ``run_in_threadpool`` so that blocking setup/teardown does not stall the
    event loop. The semantics mirror a plain ``with cm as value:`` statement:

    * ``__exit__`` is only called if ``__enter__`` returned successfully.
    * If the ``async with`` body raises (including ``BaseException``
      subclasses such as ``asyncio.CancelledError`` on Python 3.8+),
      ``__exit__`` receives the exception type, value and traceback. A truthy
      return value suppresses the exception; otherwise it is re-raised.
    * On normal completion ``__exit__(None, None, None)`` is called.

    Args:
        cm: The synchronous context manager to wrap.

    Yields:
        Whatever ``cm.__enter__()`` returned.
    """
    # Like a ``with`` statement, __exit__ must only run once __enter__ has
    # succeeded; a failing __enter__ must not trigger teardown of a resource
    # that was never acquired. So __enter__ stays outside the try.
    value = await run_in_threadpool(cm.__enter__)
    try:
        yield value
    except BaseException as exc:
        # BaseException (not just Exception) so that cancellation,
        # GeneratorExit, etc. still run __exit__ instead of leaking the
        # resource. __exit__ may return None, hence the explicit bool().
        suppress = bool(
            await run_in_threadpool(
                cm.__exit__, type(exc), exc, exc.__traceback__
            )
        )
        if not suppress:
            # Re-raise the very same object so asynccontextmanager recognizes
            # it as the exception it threw in.
            raise exc
    else:
        await run_in_threadpool(cm.__exit__, None, None, None)