Browse Source

[commands] Change cooldowns to take context instead of message

pull/8265/head
Mikey 3 years ago
committed by GitHub
parent
commit
311891912e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 38
      discord/ext/commands/cooldowns.py
  2. 26
      discord/ext/commands/core.py

38
discord/ext/commands/cooldowns.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, TYPE_CHECKING from typing import Any, Callable, Deque, Dict, Optional, Union, Generic, TypeVar, TYPE_CHECKING
from discord.enums import Enum from discord.enums import Enum
import time import time
import asyncio import asyncio
@ -33,12 +33,14 @@ from collections import deque
from ...abc import PrivateChannel from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached from .errors import MaxConcurrencyReached
from .context import Context
from discord.app_commands import Cooldown as Cooldown from discord.app_commands import Cooldown as Cooldown
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
from ...message import Message from ...message import Message
from ._types import BotT
__all__ = ( __all__ = (
'BucketType', 'BucketType',
@ -48,6 +50,8 @@ __all__ = (
'MaxConcurrency', 'MaxConcurrency',
) )
T = TypeVar('T')
class BucketType(Enum): class BucketType(Enum):
default = 0 default = 0
@ -58,7 +62,7 @@ class BucketType(Enum):
category = 5 category = 5
role = 6 role = 6
def get_key(self, msg: Message) -> Any: def get_key(self, msg: Union[Message, Context[BotT]]) -> Any:
if self is BucketType.user: if self is BucketType.user:
return msg.author.id return msg.author.id
elif self is BucketType.guild: elif self is BucketType.guild:
@ -76,22 +80,22 @@ class BucketType(Enum):
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
def __call__(self, msg: Message) -> Any: def __call__(self, msg: Union[Message, Context[BotT]]) -> Any:
return self.get_key(msg) return self.get_key(msg)
class CooldownMapping: class CooldownMapping(Generic[T]):
def __init__( def __init__(
self, self,
original: Optional[Cooldown], original: Optional[Cooldown],
type: Callable[[Message], Any], type: Callable[[T], Any],
) -> None: ) -> None:
if not callable(type): if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable') raise TypeError('Cooldown type must be a BucketType or callable')
self._cache: Dict[Any, Cooldown] = {} self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original self._cooldown: Optional[Cooldown] = original
self._type: Callable[[Message], Any] = type self._type: Callable[[T], Any] = type
def copy(self) -> CooldownMapping: def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type) ret = CooldownMapping(self._cooldown, self._type)
@ -103,14 +107,14 @@ class CooldownMapping:
return self._cooldown is not None return self._cooldown is not None
@property @property
def type(self) -> Callable[[Message], Any]: def type(self) -> Callable[[T], Any]:
return self._type return self._type
@classmethod @classmethod
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self: def from_cooldown(cls, rate: float, per: float, type: Callable[[T], Any]) -> Self:
return cls(Cooldown(rate, per), type) return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any: def _bucket_key(self, msg: T) -> Any:
return self._type(msg) return self._type(msg)
def _verify_cache_integrity(self, current: Optional[float] = None) -> None: def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
@ -122,10 +126,10 @@ class CooldownMapping:
for k in dead_keys: for k in dead_keys:
del self._cache[k] del self._cache[k]
def create_bucket(self, message: Message) -> Cooldown: def create_bucket(self, message: T) -> Cooldown:
return self._cooldown.copy() # type: ignore return self._cooldown.copy() # type: ignore
def get_bucket(self, message: Message, current: Optional[float] = None) -> Optional[Cooldown]: def get_bucket(self, message: T, current: Optional[float] = None) -> Optional[Cooldown]:
if self._type is BucketType.default: if self._type is BucketType.default:
return self._cooldown return self._cooldown
@ -140,21 +144,21 @@ class CooldownMapping:
return bucket return bucket
def update_rate_limit(self, message: Message, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: def update_rate_limit(self, message: T, current: Optional[float] = None, tokens: int = 1) -> Optional[float]:
bucket = self.get_bucket(message, current) bucket = self.get_bucket(message, current)
if bucket is None: if bucket is None:
return None return None
return bucket.update_rate_limit(current, tokens=tokens) return bucket.update_rate_limit(current, tokens=tokens)
class DynamicCooldownMapping(CooldownMapping): class DynamicCooldownMapping(CooldownMapping[T]):
def __init__( def __init__(
self, self,
factory: Callable[[Message], Optional[Cooldown]], factory: Callable[[T], Optional[Cooldown]],
type: Callable[[Message], Any], type: Callable[[T], Any],
) -> None: ) -> None:
super().__init__(None, type) super().__init__(None, type)
self._factory: Callable[[Message], Optional[Cooldown]] = factory self._factory: Callable[[T], Optional[Cooldown]] = factory
def copy(self) -> DynamicCooldownMapping: def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type) ret = DynamicCooldownMapping(self._factory, self._type)
@ -165,7 +169,7 @@ class DynamicCooldownMapping(CooldownMapping):
def valid(self) -> bool: def valid(self) -> bool:
return True return True
def create_bucket(self, message: Message) -> Optional[Cooldown]: def create_bucket(self, message: T) -> Optional[Cooldown]:
return self._factory(message) return self._factory(message)

26
discord/ext/commands/core.py

@ -58,8 +58,6 @@ from .parameters import Parameter, Signature
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, Self from typing_extensions import Concatenate, ParamSpec, Self
from discord.message import Message
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
@ -409,10 +407,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if cooldown is None: if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default) buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping): elif isinstance(cooldown, CooldownMapping):
buckets = cooldown buckets: CooldownMapping[Context] = cooldown
else: else:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
self._buckets: CooldownMapping = buckets self._buckets: CooldownMapping[Context] = buckets
try: try:
max_concurrency = func.__commands_max_concurrency__ max_concurrency = func.__commands_max_concurrency__
@ -879,7 +877,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if self._buckets.valid: if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
bucket = self._buckets.get_bucket(ctx.message, current) bucket = self._buckets.get_bucket(ctx, current)
if bucket is not None: if bucket is not None:
retry_after = bucket.update_rate_limit(current) retry_after = bucket.update_rate_limit(current)
if retry_after: if retry_after:
@ -929,7 +927,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not self._buckets.valid: if not self._buckets.valid:
return False return False
bucket = self._buckets.get_bucket(ctx.message) bucket = self._buckets.get_bucket(ctx)
if bucket is None: if bucket is None:
return False return False
dt = ctx.message.edited_at or ctx.message.created_at dt = ctx.message.edited_at or ctx.message.created_at
@ -949,7 +947,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
The invocation context to reset the cooldown under. The invocation context to reset the cooldown under.
""" """
if self._buckets.valid: if self._buckets.valid:
bucket = self._buckets.get_bucket(ctx.message) bucket = self._buckets.get_bucket(ctx)
if bucket is not None: if bucket is not None:
bucket.reset() bucket.reset()
@ -974,7 +972,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
If this is ``0.0`` then the command isn't on cooldown. If this is ``0.0`` then the command isn't on cooldown.
""" """
if self._buckets.valid: if self._buckets.valid:
bucket = self._buckets.get_bucket(ctx.message) bucket = self._buckets.get_bucket(ctx)
if bucket is None: if bucket is None:
return 0.0 return 0.0
dt = ctx.message.edited_at or ctx.message.created_at dt = ctx.message.edited_at or ctx.message.created_at
@ -2399,7 +2397,7 @@ def is_nsfw() -> Check[Any]:
def cooldown( def cooldown(
rate: int, rate: int,
per: float, per: float,
type: Union[BucketType, Callable[[Message], Any]] = BucketType.default, type: Union[BucketType, Callable[[Context], Any]] = BucketType.default,
) -> Callable[[T], T]: ) -> Callable[[T], T]:
"""A decorator that adds a cooldown to a :class:`.Command` """A decorator that adds a cooldown to a :class:`.Command`
@ -2420,7 +2418,7 @@ def cooldown(
The number of times a command can be used before triggering a cooldown. The number of times a command can be used before triggering a cooldown.
per: :class:`float` per: :class:`float`
The amount of seconds to wait for a cooldown when it's been triggered. The amount of seconds to wait for a cooldown when it's been triggered.
type: Union[:class:`.BucketType`, Callable[[:class:`.Message`], Any]] type: Union[:class:`.BucketType`, Callable[[:class:`.Context`], Any]]
The type of cooldown to have. If callable, should return a key for the mapping. The type of cooldown to have. If callable, should return a key for the mapping.
.. versionchanged:: 1.7 .. versionchanged:: 1.7
@ -2431,15 +2429,15 @@ def cooldown(
if isinstance(func, Command): if isinstance(func, Command):
func._buckets = CooldownMapping(Cooldown(rate, per), type) func._buckets = CooldownMapping(Cooldown(rate, per), type)
else: else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) # type: ignore # typevar cannot be inferred without annotation
return func return func
return decorator # type: ignore return decorator # type: ignore
def dynamic_cooldown( def dynamic_cooldown(
cooldown: Union[BucketType, Callable[[Message], Any]], cooldown: Callable[[Context], Cooldown | None],
type: BucketType, type: BucketType | Callable[[Context], Any],
) -> Callable[[T], T]: ) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a :class:`.Command` """A decorator that adds a dynamic cooldown to a :class:`.Command`
@ -2463,7 +2461,7 @@ def dynamic_cooldown(
Parameters Parameters
------------ ------------
cooldown: Callable[[:class:`.discord.Message`], Optional[:class:`~discord.app_commands.Cooldown`]] cooldown: Callable[[:class:`.Context`], Optional[:class:`~discord.app_commands.Cooldown`]]
A function that takes a message and returns a cooldown that will A function that takes a message and returns a cooldown that will
apply to this invocation or ``None`` if the cooldown should be bypassed. apply to this invocation or ``None`` if the cooldown should be bypassed.
type: :class:`.BucketType` type: :class:`.BucketType`

Loading…
Cancel
Save