
Merge dd0209864f into 6df50d40fe

pull/11895/merge
Leo Bergolth, 4 days ago (committed by GitHub)
parent commit 3324148823
Files changed (6):

  1. fastapi/concurrency.py (35)
  2. fastapi/dependencies/models.py (2)
  3. fastapi/dependencies/utils.py (26)
  4. fastapi/param_functions.py (33)
  5. fastapi/params.py (18)
  6. tests/test_depends_limiter.py (130)

fastapi/concurrency.py (35)

@@ -1,20 +1,41 @@
+import functools
+import sys
+import typing
 from contextlib import asynccontextmanager as asynccontextmanager
-from typing import AsyncGenerator, ContextManager, TypeVar
+from typing import AsyncGenerator, ContextManager, Optional, TypeVar
+
+if sys.version_info >= (3, 10):  # pragma: no cover
+    from typing import ParamSpec
+else:  # pragma: no cover
+    from typing_extensions import ParamSpec
 
 import anyio.to_thread
 from anyio import CapacityLimiter
 from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool  # noqa
-from starlette.concurrency import run_in_threadpool as run_in_threadpool  # noqa
 from starlette.concurrency import (  # noqa
     run_until_first_complete as run_until_first_complete,
 )
 
+_P = ParamSpec("_P")
 _T = TypeVar("_T")
 
 
+async def run_in_threadpool(
+    func: typing.Callable[_P, _T],
+    *args: typing.Any,
+    _limiter: Optional[anyio.CapacityLimiter] = None,
+    **kwargs: typing.Any,
+) -> _T:
+    if kwargs:  # pragma: no cover
+        # run_sync doesn't accept 'kwargs', so bind them in here
+        func = functools.partial(func, **kwargs)
+    return await anyio.to_thread.run_sync(func, *args, limiter=_limiter)
+
+
 @asynccontextmanager
 async def contextmanager_in_threadpool(
     cm: ContextManager[_T],
+    limiter: Optional[anyio.CapacityLimiter] = None,
 ) -> AsyncGenerator[_T, None]:
     # blocking __exit__ from running waiting on a free thread
     # can create race conditions/deadlocks if the context manager itself
@@ -24,16 +45,14 @@ async def contextmanager_in_threadpool(
     # works (1 is arbitrary)
     exit_limiter = CapacityLimiter(1)
     try:
-        yield await run_in_threadpool(cm.__enter__)
+        yield await run_in_threadpool(cm.__enter__, _limiter=limiter)
     except Exception as e:
         ok = bool(
-            await anyio.to_thread.run_sync(
-                cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
+            await run_in_threadpool(
+                cm.__exit__, type(e), e, e.__traceback__, _limiter=exit_limiter
             )
         )
         if not ok:
             raise e
     else:
-        await anyio.to_thread.run_sync(
-            cm.__exit__, None, None, None, limiter=exit_limiter
-        )
+        await run_in_threadpool(cm.__exit__, None, None, None, _limiter=exit_limiter)
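
Editor's illustration (not part of the patch): a minimal sketch of exercising the new `run_in_threadpool` wrapper directly, assuming this branch is installed. The `blocking_io` helper and the token count are made up; the wrapper itself just forwards to `anyio.to_thread.run_sync` with the given limiter.

```python
import time

import anyio

from fastapi.concurrency import run_in_threadpool  # wrapper added in this diff


def blocking_io(n: int) -> int:
    time.sleep(0.1)  # stand-in for a blocking call (DB driver, C extension, ...)
    return n * 2


async def main() -> None:
    # Created inside the event loop so this also works on anyio < 4.2.0
    # (see the note in tests/test_depends_limiter.py below).
    limiter = anyio.CapacityLimiter(2)  # at most 2 calls occupy threads at once
    results = []

    async def worker(i: int) -> None:
        # _limiter routes the call through the dedicated limiter instead of
        # anyio's default thread limiter
        results.append(await run_in_threadpool(blocking_io, i, _limiter=limiter))

    async with anyio.create_task_group() as tg:
        for i in range(5):
            tg.start_soon(worker, i)

    print(sorted(results))  # [0, 2, 4, 6, 8]


anyio.run(main)
```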

fastapi/dependencies/models.py (2)

@@ -1,6 +1,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, List, Optional, Sequence, Tuple
 
+import anyio
 from fastapi._compat import ModelField
 from fastapi.security.base import SecurityBase
 
@@ -32,6 +33,7 @@ class Dependant:
     use_cache: bool = True
     path: Optional[str] = None
     cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
+    limiter: Optional[anyio.CapacityLimiter] = None
 
     def __post_init__(self) -> None:
         self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

fastapi/dependencies/utils.py (26)

@@ -50,6 +50,7 @@ from fastapi.background import BackgroundTasks
 from fastapi.concurrency import (
     asynccontextmanager,
     contextmanager_in_threadpool,
+    run_in_threadpool,
 )
 from fastapi.dependencies.models import Dependant, SecurityRequirement
 from fastapi.logger import logger
@@ -60,7 +61,6 @@ from fastapi.utils import create_model_field, get_path_param_names
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from starlette.background import BackgroundTasks as StarletteBackgroundTasks
-from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import (
     FormData,
     Headers,
@@ -165,6 +165,7 @@ def get_sub_dependant(
         name=name,
         security_scopes=security_scopes,
         use_cache=depends.use_cache,
+        limiter=depends.limiter,
     )
     if security_requirement:
         sub_dependant.security_requirements.append(security_requirement)
@@ -192,6 +193,7 @@ def get_flat_dependant(
         body_params=dependant.body_params.copy(),
         security_requirements=dependant.security_requirements.copy(),
         use_cache=dependant.use_cache,
+        limiter=dependant.limiter,
         path=dependant.path,
     )
     for sub_dependant in dependant.dependencies:
@@ -269,6 +271,7 @@ def get_dependant(
     name: Optional[str] = None,
     security_scopes: Optional[List[str]] = None,
     use_cache: bool = True,
+    limiter: Optional[anyio.CapacityLimiter] = None,
 ) -> Dependant:
     path_param_names = get_path_param_names(path)
     endpoint_signature = get_typed_signature(call)
@@ -279,6 +282,7 @@ def get_dependant(
         path=path,
         security_scopes=security_scopes,
         use_cache=use_cache,
+        limiter=limiter,
     )
     for param_name, param in signature_params.items():
         is_path_param = param_name in path_param_names
@@ -551,10 +555,16 @@ def is_gen_callable(call: Callable[..., Any]) -> bool:
 async def solve_generator(
-    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
+    *,
+    call: Callable[..., Any],
+    stack: AsyncExitStack,
+    sub_values: Dict[str, Any],
+    limiter: Optional[anyio.CapacityLimiter] = None,
 ) -> Any:
     if is_gen_callable(call):
-        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
+        cm = contextmanager_in_threadpool(
+            contextmanager(call)(**sub_values), limiter=limiter
+        )
     elif is_async_gen_callable(call):
         cm = asynccontextmanager(call)(**sub_values)
     return await stack.enter_async_context(cm)
@@ -610,6 +620,7 @@ async def solve_dependencies(
                 call=call,
                 name=sub_dependant.name,
                 security_scopes=sub_dependant.security_scopes,
+                limiter=sub_dependant.limiter,
             )
 
             solved_result = await solve_dependencies(
@@ -632,12 +643,17 @@ async def solve_dependencies(
             solved = dependency_cache[sub_dependant.cache_key]
         elif is_gen_callable(call) or is_async_gen_callable(call):
             solved = await solve_generator(
-                call=call, stack=async_exit_stack, sub_values=solved_result.values
+                call=call,
+                stack=async_exit_stack,
+                sub_values=solved_result.values,
+                limiter=sub_dependant.limiter,
             )
         elif is_coroutine_callable(call):
             solved = await call(**solved_result.values)
         else:
-            solved = await run_in_threadpool(call, **solved_result.values)
+            solved = await run_in_threadpool(
+                call, _limiter=sub_dependant.limiter, **solved_result.values
+            )
         if sub_dependant.name is not None:
             values[sub_dependant.name] = solved
         if sub_dependant.cache_key not in dependency_cache:
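
Editor's illustration (not part of the patch): with the plumbing above, a synchronous generator dependency declared with a limiter has its setup run through `contextmanager_in_threadpool` under that limiter, while plain sync dependencies go through `run_in_threadpool` with `_limiter`. A minimal sketch, assuming this branch and anyio >= 4.2.0 (so the limiter can be built at import time); `heavy_report` and the token count are hypothetical.

```python
import time
from typing import Iterator

import anyio
from fastapi import Depends, FastAPI
from typing_extensions import Annotated

app = FastAPI()

# Assumes anyio >= 4.2.0; older versions need the module-level workaround
# used in tests/test_depends_limiter.py below.
report_limiter = anyio.CapacityLimiter(2)


def heavy_report() -> Iterator[str]:
    # Sync generator dependency: the code before `yield` runs in a worker
    # thread guarded by report_limiter (via contextmanager_in_threadpool);
    # teardown after the response runs under the internal one-token exit limiter.
    time.sleep(0.5)  # stand-in for blocking setup work
    yield "report-data"
    # blocking cleanup would go here


@app.get("/report")
async def build_report(
    report: Annotated[str, Depends(heavy_report, limiter=report_limiter)],
):
    return {"report": report}
```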

fastapi/param_functions.py (33)

@@ -1,5 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
+import anyio
 from fastapi import params
 from fastapi._compat import Undefined
 from fastapi.openapi.models import Example
@@ -2244,6 +2245,20 @@ def Depends(  # noqa: N802
             """
         ),
     ] = True,
+    limiter: Annotated[
+        Optional[anyio.CapacityLimiter],
+        Doc(
+            """
+            By default, synchronous dependencies will be run in a threadpool
+            with the number of concurrent threads limited by the current default anyio
+            thread limiter. A different `anyio.CapacityLimiter` may be specified
+            for problematic dependencies to use a different (logical) thread pool with
+            other limits in order to avoid blocking other threads.
+
+            For async dependencies (defined using `async def`) this parameter is ignored.
+            """
+        ),
+    ] = None,
 ) -> Any:
     """
     Declare a FastAPI dependency.
@@ -2274,7 +2289,7 @@ def Depends(  # noqa: N802
         return commons
     ```
     """
-    return params.Depends(dependency=dependency, use_cache=use_cache)
+    return params.Depends(dependency=dependency, use_cache=use_cache, limiter=limiter)
 
 
 def Security(  # noqa: N802
@@ -2321,6 +2336,18 @@ def Security(  # noqa: N802
             """
         ),
     ] = True,
+    limiter: Annotated[
+        Optional[anyio.CapacityLimiter],
+        Doc(
+            """
+            By default, synchronous dependencies will be run in a threadpool
+            with the number of concurrent threads limited by the current default anyio
+            thread limiter. A different `anyio.CapacityLimiter` may be specified
+            for problematic dependencies to use a different (logical) thread pool with
+            other limits in order to avoid blocking other threads.
+            """
+        ),
+    ] = None,
 ) -> Any:
     """
     Declare a FastAPI Security dependency.
@@ -2357,4 +2384,6 @@ def Security(  # noqa: N802
         return [{"item_id": "Foo", "owner": current_user.username}]
     ```
     """
-    return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
+    return params.Security(
+        dependency=dependency, scopes=scopes, use_cache=use_cache, limiter=limiter
+    )
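
Editor's illustration (not part of the patch): how the new `limiter` argument reads at a call site, assuming this branch and anyio >= 4.2.0. The `legacy_lookup` dependency and the token count are hypothetical; the point is that a known-slow sync dependency gets its own capacity instead of competing for anyio's default thread limiter.

```python
import time

import anyio
from fastapi import Depends, FastAPI
from typing_extensions import Annotated

app = FastAPI()

# Hypothetical: give a slow sync dependency its own 4-thread budget so it
# cannot exhaust the default threadpool shared by the rest of the app.
legacy_limiter = anyio.CapacityLimiter(4)


def legacy_lookup(q: str = "") -> dict:
    time.sleep(1.0)  # blocking call into a legacy library (stand-in)
    return {"q": q}


@app.get("/lookup")
async def lookup(
    result: Annotated[dict, Depends(legacy_lookup, limiter=legacy_limiter)],
):
    # legacy_lookup runs via run_in_threadpool(..., _limiter=legacy_limiter);
    # an `async def` dependency would ignore the limiter, as the docstring notes.
    return result
```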

fastapi/params.py (18)

@@ -2,6 +2,7 @@ import warnings
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
+import anyio
 from fastapi.openapi.models import Example
 from pydantic.fields import FieldInfo
 from typing_extensions import Annotated, deprecated
@@ -763,15 +764,25 @@ class File(Form):
 class Depends:
     def __init__(
-        self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
+        self,
+        dependency: Optional[Callable[..., Any]] = None,
+        *,
+        use_cache: bool = True,
+        limiter: Optional[anyio.CapacityLimiter] = None,
     ):
         self.dependency = dependency
         self.use_cache = use_cache
+        self.limiter = limiter
 
     def __repr__(self) -> str:
         attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
         cache = "" if self.use_cache else ", use_cache=False"
-        return f"{self.__class__.__name__}({attr}{cache})"
+        limiter = (
+            f", limiter=CapacityLimiter({self.limiter.total_tokens})"
+            if self.limiter
+            else ""
+        )
+        return f"{self.__class__.__name__}({attr}{cache}{limiter})"
 
 
 class Security(Depends):
@@ -781,6 +792,7 @@ class Security(Depends):
         *,
         scopes: Optional[Sequence[str]] = None,
         use_cache: bool = True,
+        limiter: Optional[anyio.CapacityLimiter] = None,
     ):
-        super().__init__(dependency=dependency, use_cache=use_cache)
+        super().__init__(dependency=dependency, use_cache=use_cache, limiter=limiter)
         self.scopes = scopes or []
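
Editor's illustration (not part of the patch): the expected shape of the extended `__repr__`, constructed inside a running event loop so the `CapacityLimiter` can be created on any supported anyio version. The exact rendering of `total_tokens` is an assumption.

```python
import anyio
from fastapi import Depends


def dep() -> int:
    return 42


async def show() -> None:
    limiter = anyio.CapacityLimiter(4)
    print(repr(Depends(dep, limiter=limiter)))
    # expected to read roughly: Depends(dep, limiter=CapacityLimiter(4))
    print(repr(Depends(dep, use_cache=False)))
    # unchanged behaviour: Depends(dep, use_cache=False)


anyio.run(show)
```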

tests/test_depends_limiter.py (130)

@@ -0,0 +1,130 @@ (new file)
from contextlib import asynccontextmanager

import anyio
import sniffio
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from typing_extensions import Annotated


def get_borrowed_tokens():
    return {lname: lobj.borrowed_tokens for lname, lobj in sorted(limiters.items())}


def limited_dep():
    # run this in the event loop thread:
    tokens = anyio.from_thread.run_sync(get_borrowed_tokens)
    yield tokens


def init_limiters():
    return {
        "default": None,  # should be set in Lifespan handler
        "a": anyio.CapacityLimiter(5),
        "b": anyio.CapacityLimiter(3),
    }


# Note:
# initializing CapacityLimiters at module level before the event loop has started
# needs anyio >= 4.2.0; starlette currently requires anyio >= 3.4.0
# see https://github.com/agronholm/anyio/pull/651
# The following is a temporary workaround for anyio < 4.2.0:
try:
    limiters = init_limiters()
except sniffio.AsyncLibraryNotFoundError:
    # not in an async context yet
    async def _init_limiters():
        return init_limiters()

    limiters = anyio.run(_init_limiters)


@asynccontextmanager
async def lifespan(app: FastAPI):
    limiters["default"] = anyio.to_thread.current_default_thread_limiter()
    yield {
        "limiters": limiters,
    }


app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root(borrowed_tokens: Annotated[dict, Depends(limited_dep)]):
    return borrowed_tokens


@app.get("/a")
async def a(
    borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["a"])],
):
    return borrowed_tokens


@app.get("/b")
async def b(
    borrowed_tokens: Annotated[dict, Depends(limited_dep, limiter=limiters["b"])],
):
    return borrowed_tokens


def test_depends_limiter():
    with TestClient(app) as client:
        response = client.get("/")
        assert response.status_code == 200, response.text
        assert response.json() == {"a": 0, "b": 0, "default": 1}

        response = client.get("/a")
        assert response.status_code == 200, response.text
        assert response.json() == {"a": 1, "b": 0, "default": 0}

        response = client.get("/b")
        assert response.status_code == 200, response.text
        assert response.json() == {"a": 0, "b": 1, "default": 0}


def test_openapi_schema():
    with TestClient(app) as client:
        response = client.get("/openapi.json")
        assert response.status_code == 200, response.text
        assert response.json()["paths"] == {
            "/": {
                "get": {
                    "summary": "Root",
                    "operationId": "root__get",
                    "responses": {
                        "200": {
                            "description": "Successful Response",
                            "content": {"application/json": {"schema": {}}},
                        }
                    },
                }
            },
            "/a": {
                "get": {
                    "summary": "A",
                    "operationId": "a_a_get",
                    "responses": {
                        "200": {
                            "description": "Successful Response",
                            "content": {"application/json": {"schema": {}}},
                        }
                    },
                }
            },
            "/b": {
                "get": {
                    "summary": "B",
                    "operationId": "b_b_get",
                    "responses": {
                        "200": {
                            "description": "Successful Response",
                            "content": {"application/json": {"schema": {}}},
                        }
                    },
                }
            },
        }