From 70323105f975545dbe34b956460c5553a4035aa0 Mon Sep 17 00:00:00 2001 From: Noah Klein Date: Tue, 27 Aug 2024 02:36:04 +0000 Subject: [PATCH] add deadlock test that locks before introducing the fix --- tests/test_depends_deadlock.py | 72 ++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/test_depends_deadlock.py diff --git a/tests/test_depends_deadlock.py b/tests/test_depends_deadlock.py new file mode 100644 index 000000000..7e71d4663 --- /dev/null +++ b/tests/test_depends_deadlock.py @@ -0,0 +1,72 @@ +import pytest +from fastapi import FastAPI, Depends +from fastapi.testclient import TestClient +from pydantic import BaseModel +from typing import Generator +import threading +import time +import httpx +import asyncio + +app = FastAPI() + +# Dummy pydantic model +class Item(BaseModel): + name: str + id: int + +# Mutex, acting as our "connection pool" for a database for example +mutex_db_connection_pool = threading.Lock() + +# Simulate a database class that uses a connection pool to manage +# active clients for a db. The client needs to perform blocking +# calls to connect and disconnect from the db +class MyDB: + def __init__(self): + self.lock_acquired = False + + def connect(self): + mutex_db_connection_pool.acquire() + self.lock_acquired = True + # Sleep to simulate some blocking IO like connecting to a db + time.sleep(.001) + + def disconnect(self): + if self.lock_acquired: + # Use a sleep to simulate some blocking IO such as a db disconnect + time.sleep(.001) + mutex_db_connection_pool.release() + self.lock_acquired = False + +# Simulate getting a connection to a database from a connection pool +# using the mutex to act as this limited resource +def get_db() -> Generator[MyDB, None, None]: + my_db = MyDB() + try: + yield my_db + finally: + my_db.disconnect() + +# An endpoint that uses Depends for resource management and also includes +# a response_model definition would previously deadlock in the validation +# of the model and the cleanup of the Depends +@app.get("/deadlock", response_model=Item) +def get_deadlock(db: MyDB = Depends(get_db)): + db.connect() + return Item(name="foo", id=1) + +# Fire off 100 requests in parallel(ish) in order to create contention +# over the shared resource (simulating a fastapi server that interacts with +# a database connection pool). After the patch, each thread on the server is +# able to free the resource without deadlocking, allowing each request to +# be handled timely +def test_depends_deadlock_patch(): + async def make_request(client: httpx.AsyncClient): + response = await client.get("/deadlock") + + async def run_requests(): + async with httpx.AsyncClient(app=app, base_url="http://testserver") as aclient: + tasks = [make_request(aclient) for _ in range(100)] + await asyncio.gather(*tasks) + + asyncio.run(run_requests()) \ No newline at end of file