Browse Source

Merge dd0209864f into 6df50d40fe

pull/11895/merge
Leo Bergolth 4 days ago
committed by GitHub
parent
commit
3324148823
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 35
      fastapi/concurrency.py
  2. 2
      fastapi/dependencies/models.py
  3. 26
      fastapi/dependencies/utils.py
  4. 33
      fastapi/param_functions.py
  5. 18
      fastapi/params.py
  6. 130
      tests/test_depends_limiter.py

35
fastapi/concurrency.py

@ -1,20 +1,41 @@
import functools
import sys
import typing
from contextlib import asynccontextmanager as asynccontextmanager
from typing import AsyncGenerator, ContextManager, TypeVar
from typing import AsyncGenerator, ContextManager, Optional, TypeVar
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
import anyio.to_thread
from anyio import CapacityLimiter
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette.concurrency import ( # noqa
run_until_first_complete as run_until_first_complete,
)
_P = ParamSpec("_P")
_T = TypeVar("_T")
async def run_in_threadpool(
func: typing.Callable[_P, _T],
*args: typing.Any,
_limiter: Optional[anyio.CapacityLimiter] = None,
**kwargs: typing.Any,
) -> _T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args, limiter=_limiter)
@asynccontextmanager
async def contextmanager_in_threadpool(
cm: ContextManager[_T],
limiter: Optional[anyio.CapacityLimiter] = None,
) -> AsyncGenerator[_T, None]:
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself
@ -24,16 +45,14 @@ async def contextmanager_in_threadpool(
# works (1 is arbitrary)
exit_limiter = CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
yield await run_in_threadpool(cm.__enter__, _limiter=limiter)
except Exception as e:
ok = bool(
await anyio.to_thread.run_sync(
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
await run_in_threadpool(
cm.__exit__, type(e), e, e.__traceback__, _limiter=exit_limiter
)
)
if not ok:
raise e
else:
await anyio.to_thread.run_sync(
cm.__exit__, None, None, None, limiter=exit_limiter
)
await run_in_threadpool(cm.__exit__, None, None, None, _limiter=exit_limiter)

2
fastapi/dependencies/models.py

@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple
import anyio
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
@ -32,6 +33,7 @@ class Dependant:
use_cache: bool = True
path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
limiter: Optional[anyio.CapacityLimiter] = None
def __post_init__(self) -> None:
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

26
fastapi/dependencies/utils.py

@ -50,6 +50,7 @@ from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
run_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
@ -60,7 +61,6 @@ from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import (
FormData,
Headers,
@ -165,6 +165,7 @@ def get_sub_dependant(
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
limiter=depends.limiter,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
@ -192,6 +193,7 @@ def get_flat_dependant(
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
use_cache=dependant.use_cache,
limiter=dependant.limiter,
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
@ -269,6 +271,7 @@ def get_dependant(
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
limiter: Optional[anyio.CapacityLimiter] = None,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
@ -279,6 +282,7 @@ def get_dependant(
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
limiter=limiter,
)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
@ -551,10 +555,16 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
*,
call: Callable[..., Any],
stack: AsyncExitStack,
sub_values: Dict[str, Any],
limiter: Optional[anyio.CapacityLimiter] = None,
) -> Any:
if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
cm = contextmanager_in_threadpool(
contextmanager(call)(**sub_values), limiter=limiter
)
elif is_async_gen_callable(call):
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
@ -610,6 +620,7 @@ async def solve_dependencies(
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
limiter=sub_dependant.limiter,
)
solved_result = await solve_dependencies(
@ -632,12 +643,17 @@ async def solve_dependencies(
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=solved_result.values
call=call,
stack=async_exit_stack,
sub_values=solved_result.values,
limiter=sub_dependant.limiter,
)
elif is_coroutine_callable(call):
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)
solved = await run_in_threadpool(
call, _limiter=sub_dependant.limiter, **solved_result.values
)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:

33
fastapi/param_functions.py

@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import anyio
from fastapi import params
from fastapi._compat import Undefined
from fastapi.openapi.models import Example
@ -2244,6 +2245,20 @@ def Depends( # noqa: N802
"""
),
] = True,
limiter: Annotated[
Optional[anyio.CapacityLimiter],
Doc(
"""
By default, synchronous dependencies will be run in a threadpool
with the number of concurrent threads limited by the current default anyio
thread limiter. A different `anyio.CapacityLimiter` may be specified
for problematic dependencies to use a different (logical) thread pool with
other limits in order to avoid blocking other threads.
For async dependencies (defined using `async def`) this parameter is ignored.
"""
),
] = None,
) -> Any:
"""
Declare a FastAPI dependency.
@ -2274,7 +2289,7 @@ def Depends( # noqa: N802
return commons
```
"""
return params.Depends(dependency=dependency, use_cache=use_cache)
return params.Depends(dependency=dependency, use_cache=use_cache, limiter=limiter)
def Security( # noqa: N802
@ -2321,6 +2336,18 @@ def Security( # noqa: N802
"""
),
] = True,
limiter: Annotated[
Optional[anyio.CapacityLimiter],
Doc(
"""
By default, synchronous dependencies will be run in a threadpool
with the number of concurrent threads limited by the current default anyio
thread limiter. A different `anyio.CapacityLimiter` may be specified
for problematic dependencies to use a different (logical) thread pool with
other limits in order to avoid blocking other threads.
"""
),
] = None,
) -> Any:
"""
Declare a FastAPI Security dependency.
@ -2357,4 +2384,6 @@ def Security( # noqa: N802
return [{"item_id": "Foo", "owner": current_user.username}]
```
"""
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
return params.Security(
dependency=dependency, scopes=scopes, use_cache=use_cache, limiter=limiter
)

18
fastapi/params.py

@ -2,6 +2,7 @@ import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import anyio
from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo
from typing_extensions import Annotated, deprecated
@ -763,15 +764,25 @@ class File(Form):
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
limiter: Optional[anyio.CapacityLimiter] = None,
):
self.dependency = dependency
self.use_cache = use_cache
self.limiter = limiter
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({attr}{cache})"
limiter = (
f", limiter=CapacityLimiter({self.limiter.total_tokens})"
if self.limiter
else ""
)
return f"{self.__class__.__name__}({attr}{cache}{limiter})"
class Security(Depends):
@ -781,6 +792,7 @@ class Security(Depends):
*,
scopes: Optional[Sequence[str]] = None,
use_cache: bool = True,
limiter: Optional[anyio.CapacityLimiter] = None,
):
super().__init__(dependency=dependency, use_cache=use_cache)
super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter)
self.scopes = scopes or []

130
tests/test_depends_limiter.py

@ -0,0 +1,130 @@
from contextlib import asynccontextmanager
import anyio
import sniffio
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from typing_extensions import Annotated
def get_borrowed_tokens():
return {lname: lobj.borrowed_tokens for lname, lobj in sorted(limiters.items())}
def limited_dep():
# run this in the event loop thread:
tokens = anyio.from_thread.run_sync(get_borrowed_tokens)
yield tokens
def init_limiters():
return {
"default": None, # should be set in Lifespan handler
"a": anyio.CapacityLimiter(5),
"b": anyio.CapacityLimiter(3),
}
# Note:
# initializing CapacityLimiters at module level before the event loop has started
# needs anyio >= 4.2.0; starlette currently requires anyio >= 3.4.0
# see https://github.com/agronholm/anyio/pull/651
# The following is a temporary workaround for anyio < 4.2.0:
try:
limiters = init_limiters()
except sniffio.AsyncLibraryNotFoundError:
# not in an async context yet
async def _init_limiters():
return init_limiters()
limiters = anyio.run(_init_limiters)
@asynccontextmanager
async def lifespan(app: FastAPI):
limiters["default"] = anyio.to_thread.current_default_thread_limiter()
yield {
"limiters": limiters,
}
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]):
return borrowed_tokens
@app.get("/a")
async def a(
borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["a"])],
):
return borrowed_tokens
@app.get("/b")
async def b(
borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["b"])],
):
return borrowed_tokens
def test_depends_limiter():
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"a": 0, "b": 0, "default": 1}
response = client.get("/a")
assert response.status_code == 200, response.text
assert response.json() == {"a": 1, "b": 0, "default": 0}
response = client.get("/b")
assert response.status_code == 200, response.text
assert response.json() == {"a": 0, "b": 1, "default": 0}
def test_openapi_schema():
with TestClient(app) as client:
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json()["paths"] == {
"/": {
"get": {
"summary": "Root",
"operationId": "root__get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
},
"/a": {
"get": {
"summary": "A",
"operationId": "a_a_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
},
"/b": {
"get": {
"summary": "B",
"operationId": "b_b_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
}
},
}
Loading…
Cancel
Save