From d61486278ff72f0c31bcea14dc3173ebdc1e8850 Mon Sep 17 00:00:00 2001 From: Mikey <8661717+sgtlaggy@users.noreply.github.com> Date: Sun, 28 Mar 2021 03:31:51 -0700 Subject: [PATCH] [commands] allow arbitrary callables in cooldown --- discord/ext/commands/cooldowns.py | 9 ++++++--- discord/ext/commands/core.py | 7 +++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index cd7d67ea8..54a533961 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -66,6 +66,9 @@ class BucketType(Enum): # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id + def __call__(self, msg): + return self.get_key(msg) + class Cooldown: __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last') @@ -78,8 +81,8 @@ class Cooldown: self._tokens = self.rate self._last = 0.0 - if not isinstance(self.type, BucketType): - raise TypeError('Cooldown type must be a BucketType') + if not callable(self.type): + raise TypeError('Cooldown type must be a BucketType or callable') def get_tokens(self, current=None): if not current: @@ -151,7 +154,7 @@ class CooldownMapping: return cls(Cooldown(rate, per, type)) def _bucket_key(self, msg): - return self._cooldown.type.get_key(msg) + return self._cooldown.type(msg) def _verify_cache_integrity(self, current=None): # we want to delete all cache objects that haven't been used diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 40c7eaf82..1c22ec0a1 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -1959,8 +1959,11 @@ def cooldown(rate, per, type=BucketType.default): The number of times a command can be used before triggering a cooldown. per: :class:`float` The amount of seconds to wait for a cooldown when it's been triggered. - type: :class:`.BucketType` - The type of cooldown to have. + type: Union[:class:`.BucketType`, Callable[[:class:`.Message`], Any]] + The type of cooldown to have. If callable, should return a key for the mapping. + + .. versionchanged:: 1.7 + Callables are now supported for custom bucket types. """ def decorator(func):