committed by
GitHub
4 changed files with 159 additions and 17 deletions
@ -0,0 +1,68 @@ |
|||
import contextlib |
|||
import time |
|||
from collections.abc import Iterator |
|||
|
|||
import anyio.to_thread |
|||
import pytest |
|||
from anyio import CapacityLimiter |
|||
from fastapi import concurrency |
|||
|
|||
|
|||
@pytest.fixture |
|||
def reset_teardown_limiter(monkeypatch: pytest.MonkeyPatch) -> None: |
|||
"""Reset the teardown limiter before/after tests to avoid interference |
|||
between different anyio backends.""" |
|||
monkeypatch.setattr(concurrency, "_teardown_limiter", CapacityLimiter(5)) |
|||
|
|||
|
|||
@pytest.mark.anyio |
|||
@pytest.mark.usefixtures("reset_teardown_limiter") |
|||
async def test_run_in_teardown_threadpool() -> None: |
|||
def func(x: int, y: int) -> int: |
|||
return x + y |
|||
|
|||
result = await concurrency.run_in_teardown_threadpool(func, 1, y=2) |
|||
assert result == 3 |
|||
|
|||
|
|||
@pytest.mark.anyio |
|||
@pytest.mark.usefixtures("reset_teardown_limiter") |
|||
async def test_contextmanager_in_threadpool() -> None: |
|||
@contextlib.contextmanager |
|||
def context_manager() -> Iterator[str]: |
|||
yield "entered" |
|||
|
|||
async with concurrency.contextmanager_in_threadpool(context_manager()) as result: |
|||
assert result == "entered" |
|||
|
|||
|
|||
@pytest.mark.anyio |
|||
@pytest.mark.usefixtures("reset_teardown_limiter") |
|||
async def test_competing_acquire_release() -> None: |
|||
"""Check that the main threadpool does not block the teardown threadpool.""" |
|||
pool_size = anyio.to_thread.current_default_thread_limiter().total_tokens |
|||
acquirable = False |
|||
acquired = [] |
|||
|
|||
def acquire() -> None: |
|||
while not acquirable: |
|||
time.sleep(0.001) |
|||
acquired.append(True) |
|||
|
|||
def release() -> bool: |
|||
nonlocal acquirable |
|||
time.sleep(0.001) |
|||
acquirable = True |
|||
return acquirable |
|||
|
|||
async with anyio.create_task_group() as tg: |
|||
for _ in range(pool_size): |
|||
tg.start_soon(concurrency.run_in_threadpool, acquire) |
|||
|
|||
await anyio.sleep(0.001) |
|||
|
|||
# The threadpool should now be full of threads waiting to acquire |
|||
# The release function should be able to run without being blocked by acquires |
|||
await concurrency.run_in_teardown_threadpool(release) |
|||
|
|||
assert len(acquired) == pool_size |
|||
@ -0,0 +1,56 @@ |
|||
import asyncio |
|||
import threading |
|||
import time |
|||
from collections.abc import Iterator |
|||
|
|||
from fastapi import Depends, FastAPI |
|||
from httpx import ASGITransport, AsyncClient |
|||
from pydantic import BaseModel |
|||
|
|||
# Mutex, and dependency acting as our "connection pool" for a database for example |
|||
mutex = threading.Lock() |
|||
|
|||
|
|||
# Simulate releaasing a pooled resource in the teardown of a Depends, |
|||
# which in reality is usually a database connection or similar. |
|||
def release_resource() -> Iterator[None]: |
|||
try: |
|||
time.sleep(0.001) |
|||
yield |
|||
finally: |
|||
time.sleep(0.001) |
|||
mutex.release() |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
class Item(BaseModel): |
|||
name: str |
|||
id: int |
|||
|
|||
|
|||
# An endpoint that uses Depends for resource management and also includes |
|||
# a response_model definition would previously deadlock in the validation |
|||
# of the model and the cleanup of the Depends |
|||
@app.get("/deadlock", response_model=Item) |
|||
def get_deadlock(dep: None = Depends(release_resource)) -> Item: |
|||
mutex.acquire() |
|||
return Item(name="foo", id=1) |
|||
|
|||
|
|||
# Fire off 100 requests in parallel(ish) in order to create contention |
|||
# over the shared resource (simulating a fastapi server that interacts with |
|||
# a database connection pool). |
|||
def test_depends_deadlock() -> None: |
|||
async def make_request(client: AsyncClient): |
|||
await client.get("/deadlock") |
|||
|
|||
async def run_requests() -> None: |
|||
async with AsyncClient( |
|||
transport=ASGITransport(app=app), base_url="http://testserver" |
|||
) as aclient: |
|||
tasks = [make_request(aclient) for _ in range(100)] |
|||
await asyncio.gather(*tasks) |
|||
|
|||
asyncio.run(run_requests()) |
|||
Loading…
Reference in new issue