You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
2.1 KiB

import contextlib
import time
from collections.abc import Iterator
import anyio.to_thread
import pytest
from anyio import CapacityLimiter
from fastapi import concurrency
@pytest.fixture
def reset_teardown_limiter(monkeypatch: pytest.MonkeyPatch) -> None:
"""Reset the teardown limiter before/after tests to avoid interference
between different anyio backends."""
monkeypatch.setattr(concurrency, "_teardown_limiter", CapacityLimiter(5))
@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_run_in_teardown_threadpool() -> None:
def func(x: int, y: int) -> int:
return x + y
result = await concurrency.run_in_teardown_threadpool(func, 1, y=2)
assert result == 3
@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_contextmanager_in_threadpool() -> None:
@contextlib.contextmanager
def context_manager() -> Iterator[str]:
yield "entered"
async with concurrency.contextmanager_in_threadpool(context_manager()) as result:
assert result == "entered"
@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_competing_acquire_release() -> None:
"""Check that the main threadpool does not block the teardown threadpool."""
pool_size = anyio.to_thread.current_default_thread_limiter().total_tokens
acquirable = False
acquired = []
def acquire() -> None:
while not acquirable:
time.sleep(0.001)
acquired.append(True)
def release() -> bool:
nonlocal acquirable
time.sleep(0.001)
acquirable = True
return acquirable
async with anyio.create_task_group() as tg:
for _ in range(pool_size):
tg.start_soon(concurrency.run_in_threadpool, acquire)
await anyio.sleep(0.001)
# The threadpool should now be full of threads waiting to acquire
# The release function should be able to run without being blocked by acquires
await concurrency.run_in_teardown_threadpool(release)
assert len(acquired) == pool_size