Browse Source

Merge 3140dc8b2a into 460f8d2cc8

pull/15388/merge
Oliver Margetts 11 hours ago
committed by GitHub
parent
commit
852928f80a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 44
      fastapi/concurrency.py
  2. 8
      fastapi/routing.py
  3. 68
      tests/test_concurrency.py
  4. 56
      tests/test_depends_deadlock.py

44
fastapi/concurrency.py

@ -1,7 +1,8 @@
from collections.abc import AsyncGenerator
import functools
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractContextManager
from contextlib import asynccontextmanager as asynccontextmanager
from typing import TypeVar
from typing import ParamSpec, TypeVar
import anyio.to_thread
from anyio import CapacityLimiter
@ -11,31 +12,44 @@ from starlette.concurrency import ( # noqa
run_until_first_complete as run_until_first_complete,
)
_P = ParamSpec("_P")
_T = TypeVar("_T")
# Blocking __exit__ and other teardown operations from running can create race
# conditions/deadlocks if the context manager itself has its own internal pool
# (e.g. a database connection pool).
# To avoid this maintain a separate limiter for teardown operations, so that the
# operations acquiring resources can never block operations releasing resources.
# NOTE: 5 is arbitrary, we would like more than 1 so that teardowns are not serialised.
_teardown_limiter = CapacityLimiter(5)
@asynccontextmanager
async def contextmanager_in_threadpool(
cm: AbstractContextManager[_T],
) -> AsyncGenerator[_T, None]:
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself
# has its own internal pool (e.g. a database connection pool)
# to avoid this we let __exit__ run without a capacity limit
# since we're creating a new limiter for each call, any non-zero limit
# works (1 is arbitrary)
exit_limiter = CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
ok = bool(
await anyio.to_thread.run_sync(
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
)
await run_in_teardown_threadpool(cm.__exit__, type(e), e, e.__traceback__)
)
if not ok:
raise e
else:
await anyio.to_thread.run_sync(
cm.__exit__, None, None, None, limiter=exit_limiter
)
await run_in_teardown_threadpool(cm.__exit__, None, None, None)
async def run_in_teardown_threadpool(
func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
"""Run a function in the separate teardown threadpool.
This will run the function in the teardown threadpool in order to avoid it
being blocked by other operations waiting to acquire resources.
Unless you know what you are doing, you probably don't want this function,
use run_in_threadpool instead.
"""
func = functools.partial(func, *args, **kwargs)
return await anyio.to_thread.run_sync(func, limiter=_teardown_limiter)

8
fastapi/routing.py

@ -38,6 +38,11 @@ from fastapi._compat import (
Undefined,
lenient_issubclass,
)
from fastapi.concurrency import (
iterate_in_threadpool,
run_in_teardown_threadpool,
run_in_threadpool,
)
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
@ -75,7 +80,6 @@ from fastapi.utils import (
from starlette import routing
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
from starlette.datastructures import FormData
from starlette.exceptions import HTTPException
from starlette.requests import Request
@ -292,7 +296,7 @@ async def serialize_response(
if is_coroutine:
value, errors = field.validate(response_content, {}, loc=("response",))
else:
value, errors = await run_in_threadpool(
value, errors = await run_in_teardown_threadpool(
field.validate, response_content, {}, loc=("response",)
)
if errors:

68
tests/test_concurrency.py

@ -0,0 +1,68 @@
import contextlib
import time
from collections.abc import Iterator
import anyio.to_thread
import pytest
from anyio import CapacityLimiter
from fastapi import concurrency
@pytest.fixture
def reset_teardown_limiter(monkeypatch: pytest.MonkeyPatch) -> None:
"""Reset the teardown limiter before/after tests to avoid interference
between different anyio backends."""
monkeypatch.setattr(concurrency, "_teardown_limiter", CapacityLimiter(5))
@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_run_in_teardown_threadpool() -> None:
def func(x: int, y: int) -> int:
return x + y
result = await concurrency.run_in_teardown_threadpool(func, 1, y=2)
assert result == 3
@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_contextmanager_in_threadpool() -> None:
@contextlib.contextmanager
def context_manager() -> Iterator[str]:
yield "entered"
async with concurrency.contextmanager_in_threadpool(context_manager()) as result:
assert result == "entered"
@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_competing_acquire_release() -> None:
"""Check that the main threadpool does not block the teardown threadpool."""
pool_size = anyio.to_thread.current_default_thread_limiter().total_tokens
acquirable = False
acquired = []
def acquire() -> None:
while not acquirable:
time.sleep(0.001)
acquired.append(True)
def release() -> bool:
nonlocal acquirable
time.sleep(0.001)
acquirable = True
return acquirable
async with anyio.create_task_group() as tg:
for _ in range(pool_size):
tg.start_soon(concurrency.run_in_threadpool, acquire)
await anyio.sleep(0.001)
# The threadpool should now be full of threads waiting to acquire
# The release function should be able to run without being blocked by acquires
await concurrency.run_in_teardown_threadpool(release)
assert len(acquired) == pool_size

56
tests/test_depends_deadlock.py

@ -0,0 +1,56 @@
import asyncio
import threading
import time
from collections.abc import Iterator
from fastapi import Depends, FastAPI
from httpx import ASGITransport, AsyncClient
from pydantic import BaseModel
# Mutex, and dependency acting as our "connection pool" for a database for example
mutex = threading.Lock()
# Simulate releaasing a pooled resource in the teardown of a Depends,
# which in reality is usually a database connection or similar.
def release_resource() -> Iterator[None]:
try:
time.sleep(0.001)
yield
finally:
time.sleep(0.001)
mutex.release()
app = FastAPI()
class Item(BaseModel):
name: str
id: int
# An endpoint that uses Depends for resource management and also includes
# a response_model definition would previously deadlock in the validation
# of the model and the cleanup of the Depends
@app.get("/deadlock", response_model=Item)
def get_deadlock(dep: None = Depends(release_resource)) -> Item:
mutex.acquire()
return Item(name="foo", id=1)
# Fire off 100 requests in parallel(ish) in order to create contention
# over the shared resource (simulating a fastapi server that interacts with
# a database connection pool).
def test_depends_deadlock() -> None:
async def make_request(client: AsyncClient):
await client.get("/deadlock")
async def run_requests() -> None:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://testserver"
) as aclient:
tasks = [make_request(aclient) for _ in range(100)]
await asyncio.gather(*tasks)
asyncio.run(run_requests())
Loading…
Cancel
Save