diff --git a/tests/test_depends_deadlock.py b/tests/test_depends_deadlock.py index 7e71d4663..268ae316e 100644 --- a/tests/test_depends_deadlock.py +++ b/tests/test_depends_deadlock.py @@ -1,23 +1,25 @@ -import pytest -from fastapi import FastAPI, Depends -from fastapi.testclient import TestClient -from pydantic import BaseModel -from typing import Generator +import asyncio import threading import time +from typing import Generator + import httpx -import asyncio +from fastapi import Depends, FastAPI +from pydantic import BaseModel app = FastAPI() + # Dummy pydantic model class Item(BaseModel): name: str id: int + # Mutex, acting as our "connection pool" for a database for example mutex_db_connection_pool = threading.Lock() + # Simulate a database class that uses a connection pool to manage # active clients for a db. The client needs to perform blocking # calls to connect and disconnect from the db @@ -29,15 +31,16 @@ class MyDB: mutex_db_connection_pool.acquire() self.lock_acquired = True # Sleep to simulate some blocking IO like connecting to a db - time.sleep(.001) + time.sleep(0.001) def disconnect(self): if self.lock_acquired: # Use a sleep to simulate some blocking IO such as a db disconnect - time.sleep(.001) + time.sleep(0.001) mutex_db_connection_pool.release() self.lock_acquired = False + # Simulate getting a connection to a database from a connection pool # using the mutex to act as this limited resource def get_db() -> Generator[MyDB, None, None]: @@ -47,6 +50,7 @@ def get_db() -> Generator[MyDB, None, None]: finally: my_db.disconnect() + # An endpoint that uses Depends for resource management and also includes # a response_model definition would previously deadlock in the validation # of the model and the cleanup of the Depends @@ -55,6 +59,7 @@ def get_deadlock(db: MyDB = Depends(get_db)): db.connect() return Item(name="foo", id=1) + # Fire off 100 requests in parallel(ish) in order to create contention # over the shared resource (simulating a fastapi server that interacts with # a database connection pool). After the patch, each thread on the server is @@ -62,11 +67,11 @@ def get_deadlock(db: MyDB = Depends(get_db)): # be handled timely def test_depends_deadlock_patch(): async def make_request(client: httpx.AsyncClient): - response = await client.get("/deadlock") + await client.get("/deadlock") async def run_requests(): async with httpx.AsyncClient(app=app, base_url="http://testserver") as aclient: tasks = [make_request(aclient) for _ in range(100)] await asyncio.gather(*tasks) - asyncio.run(run_requests()) \ No newline at end of file + asyncio.run(run_requests())