From 311891912e2de597ed44abc959b74f6fb2059d3c Mon Sep 17 00:00:00 2001 From: Mikey <8661717+sgtlaggy@users.noreply.github.com> Date: Sat, 23 Jul 2022 04:08:44 -0700 Subject: [PATCH] [commands] Change cooldowns to take context instead of message --- discord/ext/commands/cooldowns.py | 38 +++++++++++++++++-------------- discord/ext/commands/core.py | 26 ++++++++++----------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 26a0f5600..1a332370c 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, Callable, Deque, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, Deque, Dict, Optional, Union, Generic, TypeVar, TYPE_CHECKING from discord.enums import Enum import time import asyncio @@ -33,12 +33,14 @@ from collections import deque from ...abc import PrivateChannel from .errors import MaxConcurrencyReached +from .context import Context from discord.app_commands import Cooldown as Cooldown if TYPE_CHECKING: from typing_extensions import Self from ...message import Message + from ._types import BotT __all__ = ( 'BucketType', @@ -48,6 +50,8 @@ __all__ = ( 'MaxConcurrency', ) +T = TypeVar('T') + class BucketType(Enum): default = 0 @@ -58,7 +62,7 @@ class BucketType(Enum): category = 5 role = 6 - def get_key(self, msg: Message) -> Any: + def get_key(self, msg: Union[Message, Context[BotT]]) -> Any: if self is BucketType.user: return msg.author.id elif self is BucketType.guild: @@ -76,22 +80,22 @@ class BucketType(Enum): # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore - def __call__(self, msg: Message) -> Any: + def __call__(self, msg: Union[Message, Context[BotT]]) -> Any: return self.get_key(msg) -class CooldownMapping: +class CooldownMapping(Generic[T]): def __init__( self, original: Optional[Cooldown], - type: Callable[[Message], Any], + type: Callable[[T], Any], ) -> None: if not callable(type): raise TypeError('Cooldown type must be a BucketType or callable') self._cache: Dict[Any, Cooldown] = {} self._cooldown: Optional[Cooldown] = original - self._type: Callable[[Message], Any] = type + self._type: Callable[[T], Any] = type def copy(self) -> CooldownMapping: ret = CooldownMapping(self._cooldown, self._type) @@ -103,14 +107,14 @@ class CooldownMapping: return self._cooldown is not None @property - def type(self) -> Callable[[Message], Any]: + def type(self) -> Callable[[T], Any]: return self._type @classmethod - def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self: + def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self: return cls(Cooldown(rate, per), type) - def _bucket_key(self, msg: Message) -> Any: + def _bucket_key(self, msg: T) -> Any: return self._type(msg) def _verify_cache_integrity(self, current: Optional[float] = None) -> None: @@ -122,10 +126,10 @@ class CooldownMapping: for k in dead_keys: del self._cache[k] - def create_bucket(self, message: Message) -> Cooldown: + def create_bucket(self, message: T) -> Cooldown: return self._cooldown.copy() # type: ignore - def get_bucket(self, message: Message, current: Optional[float] = None) -> Optional[Cooldown]: + def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]: if self._type is BucketType.default: return self._cooldown @@ -140,21 +144,21 @@ class CooldownMapping: return bucket - def update_rate_limit(self, message: Message, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: + def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: bucket = self.get_bucket(message, current) if bucket is None: return None return bucket.update_rate_limit(current, tokens=tokens) -class DynamicCooldownMapping(CooldownMapping): +class DynamicCooldownMapping(CooldownMapping[T]): def __init__( self, - factory: Callable[[Message], Optional[Cooldown]], - type: Callable[[Message], Any], + factory: Callable[[T], Optional[Cooldown]], + type: Callable[[T], Any], ) -> None: super().__init__(None, type) - self._factory: Callable[[Message], Optional[Cooldown]] = factory + self._factory: Callable[[T], Optional[Cooldown]] = factory def copy(self) -> DynamicCooldownMapping: ret = DynamicCooldownMapping(self._factory, self._type) @@ -165,7 +169,7 @@ class DynamicCooldownMapping(CooldownMapping): def valid(self) -> bool: return True - def create_bucket(self, message: Message) -> Optional[Cooldown]: + def create_bucket(self, message: T) -> Optional[Cooldown]: return self._factory(message) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index e76edf39f..32c0b0dd1 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -58,8 +58,6 @@ from .parameters import Parameter, Signature if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, Self - from discord.message import Message - from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck @@ -409,10 +407,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): - buckets = cooldown + buckets: CooldownMapping[Context] = cooldown else: raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") - self._buckets: CooldownMapping = buckets + self._buckets: CooldownMapping[Context] = buckets try: max_concurrency = func.__commands_max_concurrency__ @@ -879,7 +877,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if self._buckets.valid: dt = ctx.message.edited_at or ctx.message.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() - bucket = self._buckets.get_bucket(ctx.message, current) + bucket = self._buckets.get_bucket(ctx, current) if bucket is not None: retry_after = bucket.update_rate_limit(current) if retry_after: @@ -929,7 +927,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if not self._buckets.valid: return False - bucket = self._buckets.get_bucket(ctx.message) + bucket = self._buckets.get_bucket(ctx) if bucket is None: return False dt = ctx.message.edited_at or ctx.message.created_at @@ -949,7 +947,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): The invocation context to reset the cooldown under. """ if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx.message) + bucket = self._buckets.get_bucket(ctx) if bucket is not None: bucket.reset() @@ -974,7 +972,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): If this is ``0.0`` then the command isn't on cooldown. """ if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx.message) + bucket = self._buckets.get_bucket(ctx) if bucket is None: return 0.0 dt = ctx.message.edited_at or ctx.message.created_at @@ -2399,7 +2397,7 @@ def is_nsfw() -> Check[Any]: def cooldown( rate: int, per: float, - type: Union[BucketType, Callable[[Message], Any]] = BucketType.default, + type: Union[BucketType, Callable[[Context], Any]] = BucketType.default, ) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` @@ -2420,7 +2418,7 @@ def cooldown( The number of times a command can be used before triggering a cooldown. per: :class:`float` The amount of seconds to wait for a cooldown when it's been triggered. - type: Union[:class:`.BucketType`, Callable[[:class:`.Message`], Any]] + type: Union[:class:`.BucketType`, Callable[[:class:`.Context`], Any]] The type of cooldown to have. If callable, should return a key for the mapping. .. versionchanged:: 1.7 @@ -2431,15 +2429,15 @@ def cooldown( if isinstance(func, Command): func._buckets = CooldownMapping(Cooldown(rate, per), type) else: - func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) + func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) # type: ignore # typevar cannot be inferred without annotation return func return decorator # type: ignore def dynamic_cooldown( - cooldown: Union[BucketType, Callable[[Message], Any]], - type: BucketType, + cooldown: Callable[[Context], Cooldown | None], + type: BucketType | Callable[[Context], Any], ) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` @@ -2463,7 +2461,7 @@ def dynamic_cooldown( Parameters ------------ - cooldown: Callable[[:class:`.discord.Message`], Optional[:class:`~discord.app_commands.Cooldown`]] + cooldown: Callable[[:class:`.Context`], Optional[:class:`~discord.app_commands.Cooldown`]] A function that takes a message and returns a cooldown that will apply to this invocation or ``None`` if the cooldown should be bypassed. type: :class:`.BucketType`