diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index fc438c9f7..cb0f75cfb 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -34,6 +34,7 @@ __all__ = ( 'BucketType', 'Cooldown', 'CooldownMapping', + 'DynamicCooldownMapping', 'MaxConcurrency', ) @@ -69,19 +70,15 @@ class BucketType(Enum): class Cooldown: - __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last') + __slots__ = ('rate', 'per', '_window', '_tokens', '_last') - def __init__(self, rate, per, type): + def __init__(self, rate, per): self.rate = int(rate) self.per = float(per) - self.type = type self._window = 0.0 self._tokens = self.rate self._last = 0.0 - if not callable(self.type): - raise TypeError('Cooldown type must be a BucketType or callable') - def get_tokens(self, current=None): if not current: current = time.time() @@ -128,15 +125,19 @@ class Cooldown: self._last = 0.0 def copy(self): - return Cooldown(self.rate, self.per, self.type) + return Cooldown(self.rate, self.per) def __repr__(self): return f'' class CooldownMapping: - def __init__(self, original): + def __init__(self, original, type): + if not callable(type): + raise TypeError('Cooldown type must be a BucketType or callable') + self._cache = {} self._cooldown = original + self._type = type def copy(self): ret = CooldownMapping(self._cooldown) @@ -152,7 +153,7 @@ class CooldownMapping: return cls(Cooldown(rate, per, type)) def _bucket_key(self, msg): - return self._cooldown.type(msg) + return self._type(msg) def _verify_cache_integrity(self, current=None): # we want to delete all cache objects that haven't been used @@ -163,15 +164,19 @@ class CooldownMapping: for k in dead_keys: del self._cache[k] + def create_bucket(self, message): + return self._cooldown.copy() + def get_bucket(self, message, current=None): - if self._cooldown.type is BucketType.default: + if self._type is BucketType.default: return self._cooldown self._verify_cache_integrity(current) key = self._bucket_key(message) if key not in self._cache: - bucket = self._cooldown.copy() - self._cache[key] = bucket + bucket = self.create_bucket(message) + if bucket is not None: + self._cache[key] = bucket else: bucket = self._cache[key] @@ -181,6 +186,19 @@ class CooldownMapping: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) +class DynamicCooldownMapping(CooldownMapping): + + def __init__(self, factory, type): + super().__init__(None, type) + self._factory = factory + + @property + def valid(self): + return True + + def create_bucket(self, message): + return self._factory(message) + class _Semaphore: """This class is a version of a semaphore. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index cce6f30cc..d735c77ba 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -32,7 +32,7 @@ import sys import discord from .errors import * -from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency +from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping from . import converter as converters from ._types import _BaseCommand from .cog import Cog @@ -54,6 +54,7 @@ __all__ = ( 'bot_has_permissions', 'bot_has_any_role', 'cooldown', + 'dynamic_cooldown', 'max_concurrency', 'dm_only', 'guild_only', @@ -256,7 +257,10 @@ class Command(_BaseCommand): except AttributeError: cooldown = kwargs.get('cooldown') finally: - self._buckets = CooldownMapping(cooldown) + if cooldown is None: + self._buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + self._buckets = cooldown try: max_concurrency = func.__commands_max_concurrency__ @@ -799,9 +803,10 @@ class Command(_BaseCommand): dt = ctx.message.edited_at or ctx.message.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() bucket = self._buckets.get_bucket(ctx.message, current) - retry_after = bucket.update_rate_limit(current) - if retry_after: - raise CommandOnCooldown(bucket, retry_after) + if bucket is not None: + retry_after = bucket.update_rate_limit(current) + if retry_after: + raise CommandOnCooldown(bucket, retry_after) async def prepare(self, ctx): ctx.command = self @@ -2014,9 +2019,48 @@ def cooldown(rate, per, type=BucketType.default): def decorator(func): if isinstance(func, Command): - func._buckets = CooldownMapping(Cooldown(rate, per, type)) + func._buckets = CooldownMapping(Cooldown(rate, per), type) else: - func.__commands_cooldown__ = Cooldown(rate, per, type) + func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) + return func + return decorator + +def dynamic_cooldown(cooldown, type=BucketType.default): + """A decorator that adds a dynamic cooldown to a :class:`.Command` + + This differs from :func:`.cooldown` in that it takes a function that + accepts a single parameter of type :class:`.discord.Message` and must + return a :class:`.Cooldown` + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns can be based + either on a per-guild, per-channel, per-user, per-role or global basis. + Denoted by the third argument of ``type`` which must be of enum + type :class:`.BucketType`. + + If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in + :func:`.on_command_error` and the local error handler. + + A command can only have a single cooldown. + + .. versionadded:: 2.0 + + Parameters + ------------ + cooldown: Callable[[:class:`.discord.Message`], :class:`.Cooldown`] + A function that takes a message and returns a cooldown that will + apply to this invocation + type: :class:`.BucketType` + The type of cooldown to have. + """ + if not callable(cooldown): + raise TypeError("A callable must be provided") + + def decorator(func): + if isinstance(func, Command): + func._buckets = DynamicCooldownMapping(cooldown, type) + else: + func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) return func return decorator