From 32a58515b96a0b8e118e8c507286b08a1055142f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:58:50 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20format?= =?UTF-8?q?=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/concurrency.py | 16 +++++++++------- fastapi/dependencies/models.py | 2 +- fastapi/dependencies/utils.py | 19 +++++++++++++------ fastapi/param_functions.py | 7 ++++--- fastapi/params.py | 13 +++++++++---- 5 files changed, 36 insertions(+), 21 deletions(-) diff --git a/fastapi/concurrency.py b/fastapi/concurrency.py index 83bb9e157..132ce4199 100644 --- a/fastapi/concurrency.py +++ b/fastapi/concurrency.py @@ -1,8 +1,9 @@ -from contextlib import asynccontextmanager as asynccontextmanager -from typing import AsyncGenerator, ContextManager, TypeVar, Optional import functools import sys import typing +from contextlib import asynccontextmanager as asynccontextmanager +from typing import AsyncGenerator, ContextManager, Optional, TypeVar + if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec else: # pragma: no cover @@ -18,20 +19,23 @@ from starlette.concurrency import ( # noqa _P = ParamSpec("_P") _T = TypeVar("_T") + async def run_in_threadpool( func: typing.Callable[_P, _T], *args: typing.Any, _limiter: Optional[anyio.CapacityLimiter] = None, - **kwargs: typing.Any + **kwargs: typing.Any, ) -> _T: if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) return await anyio.to_thread.run_sync(func, *args, limiter=_limiter) + @asynccontextmanager async def contextmanager_in_threadpool( - cm: ContextManager[_T], limiter: Optional[anyio.CapacityLimiter] = None, + cm: ContextManager[_T], + limiter: Optional[anyio.CapacityLimiter] = None, ) -> AsyncGenerator[_T, None]: # blocking __exit__ from running waiting on a free thread # can create race conditions/deadlocks if the context manager itself @@ -51,6 +55,4 @@ async def contextmanager_in_threadpool( if not ok: raise e else: - await run_in_threadpool( - cm.__exit__, None, None, None, _limiter=exit_limiter - ) + await run_in_threadpool(cm.__exit__, None, None, None, _limiter=exit_limiter) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 0891213e8..a152cf212 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Sequence, Tuple -import anyio +import anyio from fastapi._compat import ModelField from fastapi.security.base import SecurityBase diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 96a36de1d..72db9f421 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -533,12 +533,16 @@ def is_gen_callable(call: Callable[..., Any]) -> bool: async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any], + *, + call: Callable[..., Any], + stack: AsyncExitStack, + sub_values: Dict[str, Any], limiter: Optional[anyio.CapacityLimiter] = None, ) -> Any: if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values), - limiter=limiter) + cm = contextmanager_in_threadpool( + contextmanager(call)(**sub_values), limiter=limiter + ) elif is_async_gen_callable(call): cm = asynccontextmanager(call)(**sub_values) return await stack.enter_async_context(cm) @@ -615,14 +619,17 @@ async def solve_dependencies( solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values, + call=call, + stack=async_exit_stack, + sub_values=solved_result.values, limiter=sub_dependant.limiter, ) elif is_coroutine_callable(call): solved = await call(**solved_result.values) else: - solved = await run_in_threadpool(call, _limiter=sub_dependant.limiter, - **solved_result.values) + solved = await run_in_threadpool( + call, _limiter=sub_dependant.limiter, **solved_result.values + ) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 86c3c22e8..7722f3b4c 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union -import anyio +import anyio from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example @@ -2384,5 +2384,6 @@ def Security( # noqa: N802 return [{"item_id": "Foo", "owner": current_user.username}] ``` """ - return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache, - limiter=limiter) + return params.Security( + dependency=dependency, scopes=scopes, use_cache=use_cache, limiter=limiter + ) diff --git a/fastapi/params.py b/fastapi/params.py index 41494044f..a936e9892 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -1,8 +1,8 @@ import warnings from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Union -import anyio +import anyio from fastapi.openapi.models import Example from pydantic.fields import FieldInfo from typing_extensions import Annotated, deprecated @@ -761,7 +761,9 @@ class File(Form): class Depends: def __init__( - self, dependency: Optional[Callable[..., Any]] = None, *, + self, + dependency: Optional[Callable[..., Any]] = None, + *, use_cache: bool = True, limiter: Optional[anyio.CapacityLimiter] = None, ): @@ -772,8 +774,11 @@ class Depends: def __repr__(self) -> str: attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) cache = "" if self.use_cache else ", use_cache=False" - limiter = f", limiter=CapacityLimiter({self.limiter.total_tokens})" \ - if self.limiter else "" + limiter = ( + f", limiter=CapacityLimiter({self.limiter.total_tokens})" + if self.limiter + else "" + ) return f"{self.__class__.__name__}({attr}{cache}{limiter})"