|
@ -1,23 +1,25 @@ |
|
|
import pytest |
|
|
import asyncio |
|
|
from fastapi import FastAPI, Depends |
|
|
|
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
from typing import Generator |
|
|
|
|
|
import threading |
|
|
import threading |
|
|
import time |
|
|
import time |
|
|
|
|
|
from typing import Generator |
|
|
|
|
|
|
|
|
import httpx |
|
|
import httpx |
|
|
import asyncio |
|
|
from fastapi import Depends, FastAPI |
|
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
app = FastAPI() |
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dummy pydantic model |
|
|
# Dummy pydantic model |
|
|
class Item(BaseModel): |
|
|
class Item(BaseModel): |
|
|
name: str |
|
|
name: str |
|
|
id: int |
|
|
id: int |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mutex, acting as our "connection pool" for a database for example |
|
|
# Mutex, acting as our "connection pool" for a database for example |
|
|
mutex_db_connection_pool = threading.Lock() |
|
|
mutex_db_connection_pool = threading.Lock() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Simulate a database class that uses a connection pool to manage |
|
|
# Simulate a database class that uses a connection pool to manage |
|
|
# active clients for a db. The client needs to perform blocking |
|
|
# active clients for a db. The client needs to perform blocking |
|
|
# calls to connect and disconnect from the db |
|
|
# calls to connect and disconnect from the db |
|
@ -29,15 +31,16 @@ class MyDB: |
|
|
mutex_db_connection_pool.acquire() |
|
|
mutex_db_connection_pool.acquire() |
|
|
self.lock_acquired = True |
|
|
self.lock_acquired = True |
|
|
# Sleep to simulate some blocking IO like connecting to a db |
|
|
# Sleep to simulate some blocking IO like connecting to a db |
|
|
time.sleep(.001) |
|
|
time.sleep(0.001) |
|
|
|
|
|
|
|
|
def disconnect(self): |
|
|
def disconnect(self): |
|
|
if self.lock_acquired: |
|
|
if self.lock_acquired: |
|
|
# Use a sleep to simulate some blocking IO such as a db disconnect |
|
|
# Use a sleep to simulate some blocking IO such as a db disconnect |
|
|
time.sleep(.001) |
|
|
time.sleep(0.001) |
|
|
mutex_db_connection_pool.release() |
|
|
mutex_db_connection_pool.release() |
|
|
self.lock_acquired = False |
|
|
self.lock_acquired = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Simulate getting a connection to a database from a connection pool |
|
|
# Simulate getting a connection to a database from a connection pool |
|
|
# using the mutex to act as this limited resource |
|
|
# using the mutex to act as this limited resource |
|
|
def get_db() -> Generator[MyDB, None, None]: |
|
|
def get_db() -> Generator[MyDB, None, None]: |
|
@ -47,6 +50,7 @@ def get_db() -> Generator[MyDB, None, None]: |
|
|
finally: |
|
|
finally: |
|
|
my_db.disconnect() |
|
|
my_db.disconnect() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# An endpoint that uses Depends for resource management and also includes |
|
|
# An endpoint that uses Depends for resource management and also includes |
|
|
# a response_model definition would previously deadlock in the validation |
|
|
# a response_model definition would previously deadlock in the validation |
|
|
# of the model and the cleanup of the Depends |
|
|
# of the model and the cleanup of the Depends |
|
@ -55,6 +59,7 @@ def get_deadlock(db: MyDB = Depends(get_db)): |
|
|
db.connect() |
|
|
db.connect() |
|
|
return Item(name="foo", id=1) |
|
|
return Item(name="foo", id=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fire off 100 requests in parallel(ish) in order to create contention |
|
|
# Fire off 100 requests in parallel(ish) in order to create contention |
|
|
# over the shared resource (simulating a fastapi server that interacts with |
|
|
# over the shared resource (simulating a fastapi server that interacts with |
|
|
# a database connection pool). After the patch, each thread on the server is |
|
|
# a database connection pool). After the patch, each thread on the server is |
|
@ -62,11 +67,11 @@ def get_deadlock(db: MyDB = Depends(get_db)): |
|
|
# be handled timely |
|
|
# be handled timely |
|
|
def test_depends_deadlock_patch(): |
|
|
def test_depends_deadlock_patch(): |
|
|
async def make_request(client: httpx.AsyncClient): |
|
|
async def make_request(client: httpx.AsyncClient): |
|
|
response = await client.get("/deadlock") |
|
|
await client.get("/deadlock") |
|
|
|
|
|
|
|
|
async def run_requests(): |
|
|
async def run_requests(): |
|
|
async with httpx.AsyncClient(app=app, base_url="http://testserver") as aclient: |
|
|
async with httpx.AsyncClient(app=app, base_url="http://testserver") as aclient: |
|
|
tasks = [make_request(aclient) for _ in range(100)] |
|
|
tasks = [make_request(aclient) for _ in range(100)] |
|
|
await asyncio.gather(*tasks) |
|
|
await asyncio.gather(*tasks) |
|
|
|
|
|
|
|
|
asyncio.run(run_requests()) |
|
|
asyncio.run(run_requests()) |
|
|