Browse Source

run format.sh

pull/12066/head
Noah Klein 7 months ago
parent
commit
a1b1008548
  1. 25
      tests/test_depends_deadlock.py

25
tests/test_depends_deadlock.py

@ -1,23 +1,25 @@
import pytest import asyncio
from fastapi import FastAPI, Depends
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing import Generator
import threading import threading
import time import time
from typing import Generator
import httpx import httpx
import asyncio from fastapi import Depends, FastAPI
from pydantic import BaseModel
# Application under test for the Depends/response_model deadlock scenario.
app = FastAPI()
# Dummy pydantic model returned by the /deadlock endpoint; its validation
# step is half of the deadlock pair being exercised by this test.
class Item(BaseModel):
    name: str
    id: int
# Mutex acting as our "connection pool" for a database, for example: a
# single shared lock simulates the limited resource that requests contend for.
mutex_db_connection_pool = threading.Lock()
# Simulate a database class that uses a connection pool to manage
# active clients for a db. The client needs to perform blocking
# calls to connect and disconnect from the db.
@ -29,15 +31,16 @@ class MyDB:
mutex_db_connection_pool.acquire() mutex_db_connection_pool.acquire()
self.lock_acquired = True self.lock_acquired = True
# Sleep to simulate some blocking IO like connecting to a db # Sleep to simulate some blocking IO like connecting to a db
time.sleep(.001) time.sleep(0.001)
def disconnect(self):
    """Release the simulated db connection, if this instance holds one.

    No-op when the lock was never acquired, so disconnect is safe to
    call unconditionally (e.g. from a finally block).
    """
    if not self.lock_acquired:
        return
    # Sleep to simulate some blocking IO, such as a db disconnect.
    time.sleep(0.001)
    mutex_db_connection_pool.release()
    self.lock_acquired = False
# Simulate getting a connection to a database from a connection pool,
# using the mutex to act as this limited resource.
def get_db() -> Generator[MyDB, None, None]: def get_db() -> Generator[MyDB, None, None]:
@ -47,6 +50,7 @@ def get_db() -> Generator[MyDB, None, None]:
finally: finally:
my_db.disconnect() my_db.disconnect()
# An endpoint that uses Depends for resource management and also includes
# a response_model definition would previously deadlock between the validation
# of the model and the cleanup of the Depends.
@ -55,6 +59,7 @@ def get_deadlock(db: MyDB = Depends(get_db)):
db.connect() db.connect()
return Item(name="foo", id=1) return Item(name="foo", id=1)
# Fire off 100 requests in parallel(ish) in order to create contention
# over the shared resource (simulating a fastapi server that interacts with
# a database connection pool). After the patch, each thread on the server is
@ -62,11 +67,11 @@ def get_deadlock(db: MyDB = Depends(get_db)):
# be handled timely.
def test_depends_deadlock_patch():
    """Fire 100 concurrent requests at /deadlock to create lock contention.

    Before the patch this scenario deadlocked between response-model
    validation and Depends teardown; the test passes only if every request
    completes (and succeeds) in a timely manner.
    """

    async def make_request(client: httpx.AsyncClient) -> None:
        # Assert success so a 500 (e.g. from a failed dependency) surfaces
        # as a test failure instead of being silently swallowed.
        response = await client.get("/deadlock")
        assert response.status_code == 200

    async def run_requests() -> None:
        # httpx deprecated (and later removed) AsyncClient(app=...); route
        # requests into the ASGI app explicitly via ASGITransport.
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(
            transport=transport, base_url="http://testserver"
        ) as aclient:
            tasks = [make_request(aclient) for _ in range(100)]
            await asyncio.gather(*tasks)

    asyncio.run(run_requests())

Loading…
Cancel
Save