From 3c29f0f97c0335032ee0e6059e7998ba073dc438 Mon Sep 17 00:00:00 2001 From: Noah Klein Date: Sat, 24 Aug 2024 22:51:17 +0000 Subject: [PATCH 1/4] Continue patch from #5122 by not limiting response model validation --- fastapi/routing.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 49f1b6013..a78bfc0bc 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -20,6 +20,9 @@ from typing import ( Type, Union, ) +import anyio +from anyio import CapacityLimiter +import functools from fastapi import params from fastapi._compat import ( @@ -163,9 +166,10 @@ async def serialize_response( if is_coroutine: value, errors_ = field.validate(response_content, {}, loc=("response",)) else: - value, errors_ = await run_in_threadpool( - field.validate, response_content, {}, loc=("response",) - ) + # Run without a capacity limit for similar reasons as marked in fastapi/concurrency.py + exit_limiter = CapacityLimiter(1) + validate_func = functools.partial(field.validate, loc=("response",)) + value, errors_ = await anyio.to_thread.run_sync(validate_func, response_content, {}, limiter=exit_limiter) if isinstance(errors_, list): errors.extend(errors_) elif errors_: From 4cdc84db264b072d6aa7bf25cc4a4ccf28a25be8 Mon Sep 17 00:00:00 2001 From: Noah Klein Date: Sat, 24 Aug 2024 22:52:16 +0000 Subject: [PATCH 2/4] run scripts/format.sh --- fastapi/routing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index a78bfc0bc..fe37eb548 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,6 +1,7 @@ import asyncio import dataclasses import email.message +import functools import inspect import json from contextlib import AsyncExitStack, asynccontextmanager @@ -20,10 +21,9 @@ from typing import ( Type, Union, ) + import anyio from anyio import CapacityLimiter -import functools - from fastapi import params from fastapi._compat import ( ModelField, @@ -169,7 +169,9 @@ async def serialize_response( # Run without a capacity limit for similar reasons as marked in fastapi/concurrency.py exit_limiter = CapacityLimiter(1) validate_func = functools.partial(field.validate, loc=("response",)) - value, errors_ = await anyio.to_thread.run_sync(validate_func, response_content, {}, limiter=exit_limiter) + value, errors_ = await anyio.to_thread.run_sync( + validate_func, response_content, {}, limiter=exit_limiter + ) if isinstance(errors_, list): errors.extend(errors_) elif errors_: From 70323105f975545dbe34b956460c5553a4035aa0 Mon Sep 17 00:00:00 2001 From: Noah Klein Date: Tue, 27 Aug 2024 02:36:04 +0000 Subject: [PATCH 3/4] add deadlock test that locks before introducing the fix --- tests/test_depends_deadlock.py | 72 ++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/test_depends_deadlock.py diff --git a/tests/test_depends_deadlock.py b/tests/test_depends_deadlock.py new file mode 100644 index 000000000..7e71d4663 --- /dev/null +++ b/tests/test_depends_deadlock.py @@ -0,0 +1,72 @@ +import pytest +from fastapi import FastAPI, Depends +from fastapi.testclient import TestClient +from pydantic import BaseModel +from typing import Generator +import threading +import time +import httpx +import asyncio + +app = FastAPI() + +# Dummy pydantic model +class Item(BaseModel): + name: str + id: int + +# Mutex, acting as our "connection pool" for a database for example +mutex_db_connection_pool = threading.Lock() + +# Simulate a database class that uses a connection pool to manage +# active clients for a db. The client needs to perform blocking +# calls to connect and disconnect from the db +class MyDB: + def __init__(self): + self.lock_acquired = False + + def connect(self): + mutex_db_connection_pool.acquire() + self.lock_acquired = True + # Sleep to simulate some blocking IO like connecting to a db + time.sleep(.001) + + def disconnect(self): + if self.lock_acquired: + # Use a sleep to simulate some blocking IO such as a db disconnect + time.sleep(.001) + mutex_db_connection_pool.release() + self.lock_acquired = False + +# Simulate getting a connection to a database from a connection pool +# using the mutex to act as this limited resource +def get_db() -> Generator[MyDB, None, None]: + my_db = MyDB() + try: + yield my_db + finally: + my_db.disconnect() + +# An endpoint that uses Depends for resource management and also includes +# a response_model definition would previously deadlock in the validation +# of the model and the cleanup of the Depends +@app.get("/deadlock", response_model=Item) +def get_deadlock(db: MyDB = Depends(get_db)): + db.connect() + return Item(name="foo", id=1) + +# Fire off 100 requests in parallel(ish) in order to create contention +# over the shared resource (simulating a fastapi server that interacts with +# a database connection pool). After the patch, each thread on the server is +# able to free the resource without deadlocking, allowing each request to +# be handled timely +def test_depends_deadlock_patch(): + async def make_request(client: httpx.AsyncClient): + response = await client.get("/deadlock") + + async def run_requests(): + async with httpx.AsyncClient(app=app, base_url="http://testserver") as aclient: + tasks = [make_request(aclient) for _ in range(100)] + await asyncio.gather(*tasks) + + asyncio.run(run_requests()) \ No newline at end of file From a1b1008548b27fc90a13e3720b14ded04c3e3922 Mon Sep 17 00:00:00 2001 From: Noah Klein Date: Tue, 27 Aug 2024 02:36:48 +0000 Subject: [PATCH 4/4] run format.sh --- tests/test_depends_deadlock.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/test_depends_deadlock.py b/tests/test_depends_deadlock.py index 7e71d4663..268ae316e 100644 --- a/tests/test_depends_deadlock.py +++ b/tests/test_depends_deadlock.py @@ -1,23 +1,25 @@ -import pytest -from fastapi import FastAPI, Depends -from fastapi.testclient import TestClient -from pydantic import BaseModel -from typing import Generator +import asyncio import threading import time +from typing import Generator + import httpx -import asyncio +from fastapi import Depends, FastAPI +from pydantic import BaseModel app = FastAPI() + # Dummy pydantic model class Item(BaseModel): name: str id: int + # Mutex, acting as our "connection pool" for a database for example mutex_db_connection_pool = threading.Lock() + # Simulate a database class that uses a connection pool to manage # active clients for a db. The client needs to perform blocking # calls to connect and disconnect from the db @@ -29,15 +31,16 @@ class MyDB: mutex_db_connection_pool.acquire() self.lock_acquired = True # Sleep to simulate some blocking IO like connecting to a db - time.sleep(.001) + time.sleep(0.001) def disconnect(self): if self.lock_acquired: # Use a sleep to simulate some blocking IO such as a db disconnect - time.sleep(.001) + time.sleep(0.001) mutex_db_connection_pool.release() self.lock_acquired = False + # Simulate getting a connection to a database from a connection pool # using the mutex to act as this limited resource def get_db() -> Generator[MyDB, None, None]: @@ -47,6 +50,7 @@ def get_db() -> Generator[MyDB, None, None]: finally: my_db.disconnect() + # An endpoint that uses Depends for resource management and also includes # a response_model definition would previously deadlock in the validation # of the model and the cleanup of the Depends @@ -55,6 +59,7 @@ def get_deadlock(db: MyDB = Depends(get_db)): db.connect() return Item(name="foo", id=1) + # Fire off 100 requests in parallel(ish) in order to create contention # over the shared resource (simulating a fastapi server that interacts with # a database connection pool). After the patch, each thread on the server is @@ -62,11 +67,11 @@ def get_deadlock(db: MyDB = Depends(get_db)): # be handled timely def test_depends_deadlock_patch(): async def make_request(client: httpx.AsyncClient): - response = await client.get("/deadlock") + await client.get("/deadlock") async def run_requests(): async with httpx.AsyncClient(app=app, base_url="http://testserver") as aclient: tasks = [make_request(aclient) for _ in range(100)] await asyncio.gather(*tasks) - asyncio.run(run_requests()) \ No newline at end of file + asyncio.run(run_requests())