Browse Source

run format.sh

pull/12066/head
Noah Klein 7 months ago
parent
commit
a1b1008548
  1. 25
      tests/test_depends_deadlock.py

25
tests/test_depends_deadlock.py

@ -1,23 +1,25 @@
import pytest
from fastapi import FastAPI, Depends
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing import Generator
import asyncio
import threading
import time
from typing import Generator
import httpx
import asyncio
from fastapi import Depends, FastAPI
from pydantic import BaseModel
app = FastAPI()
# Dummy pydantic model
class Item(BaseModel):
name: str
id: int
# Mutex, acting as our "connection pool" for a database for example
mutex_db_connection_pool = threading.Lock()
# Simulate a database class that uses a connection pool to manage
# active clients for a db. The client needs to perform blocking
# calls to connect and disconnect from the db
@ -29,15 +31,16 @@ class MyDB:
mutex_db_connection_pool.acquire()
self.lock_acquired = True
# Sleep to simulate some blocking IO like connecting to a db
time.sleep(.001)
time.sleep(0.001)
def disconnect(self):
if self.lock_acquired:
# Use a sleep to simulate some blocking IO such as a db disconnect
time.sleep(.001)
time.sleep(0.001)
mutex_db_connection_pool.release()
self.lock_acquired = False
# Simulate getting a connection to a database from a connection pool
# using the mutex to act as this limited resource
def get_db() -> Generator[MyDB, None, None]:
@ -47,6 +50,7 @@ def get_db() -> Generator[MyDB, None, None]:
finally:
my_db.disconnect()
# An endpoint that uses Depends for resource management and also includes
# a response_model definition would previously deadlock in the validation
# of the model and the cleanup of the Depends
@ -55,6 +59,7 @@ def get_deadlock(db: MyDB = Depends(get_db)):
db.connect()
return Item(name="foo", id=1)
# Fire off 100 requests in parallel(ish) in order to create contention
# over the shared resource (simulating a fastapi server that interacts with
# a database connection pool). After the patch, each thread on the server is
@ -62,11 +67,11 @@ def get_deadlock(db: MyDB = Depends(get_db)):
# be handled timely
def test_depends_deadlock_patch():
async def make_request(client: httpx.AsyncClient):
response = await client.get("/deadlock")
await client.get("/deadlock")
async def run_requests():
async with httpx.AsyncClient(app=app, base_url="http://testserver") as aclient:
tasks = [make_request(aclient) for _ in range(100)]
await asyncio.gather(*tasks)
asyncio.run(run_requests())
asyncio.run(run_requests())

Loading…
Cancel
Save