Browse Source

[commands] Provide a dynamic cooldown system

pull/6692/head
Dan Hess 4 years ago
committed by GitHub
parent
commit
f2d5ab6f80
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 42
      discord/ext/commands/cooldowns.py
  2. 58
      discord/ext/commands/core.py

42
discord/ext/commands/cooldowns.py

@ -34,6 +34,7 @@ __all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency',
)
@ -69,19 +70,15 @@ class BucketType(Enum):
class Cooldown:
__slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
def __init__(self, rate, per, type):
def __init__(self, rate, per):
self.rate = int(rate)
self.per = float(per)
self.type = type
self._window = 0.0
self._tokens = self.rate
self._last = 0.0
if not callable(self.type):
raise TypeError('Cooldown type must be a BucketType or callable')
def get_tokens(self, current=None):
if not current:
current = time.time()
@ -128,15 +125,19 @@ class Cooldown:
self._last = 0.0
def copy(self):
return Cooldown(self.rate, self.per, self.type)
return Cooldown(self.rate, self.per)
def __repr__(self):
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping:
def __init__(self, original):
def __init__(self, original, type):
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
self._cache = {}
self._cooldown = original
self._type = type
def copy(self):
ret = CooldownMapping(self._cooldown)
@ -152,7 +153,7 @@ class CooldownMapping:
return cls(Cooldown(rate, per, type))
def _bucket_key(self, msg):
return self._cooldown.type(msg)
return self._type(msg)
def _verify_cache_integrity(self, current=None):
# we want to delete all cache objects that haven't been used
@ -163,15 +164,19 @@ class CooldownMapping:
for k in dead_keys:
del self._cache[k]
def create_bucket(self, message):
return self._cooldown.copy()
def get_bucket(self, message, current=None):
if self._cooldown.type is BucketType.default:
if self._type is BucketType.default:
return self._cooldown
self._verify_cache_integrity(current)
key = self._bucket_key(message)
if key not in self._cache:
bucket = self._cooldown.copy()
self._cache[key] = bucket
bucket = self.create_bucket(message)
if bucket is not None:
self._cache[key] = bucket
else:
bucket = self._cache[key]
@ -181,6 +186,19 @@ class CooldownMapping:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(self, factory, type):
super().__init__(None, type)
self._factory = factory
@property
def valid(self):
return True
def create_bucket(self, message):
return self._factory(message)
class _Semaphore:
"""This class is a version of a semaphore.

58
discord/ext/commands/core.py

@ -32,7 +32,7 @@ import sys
import discord
from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
from . import converter as converters
from ._types import _BaseCommand
from .cog import Cog
@ -54,6 +54,7 @@ __all__ = (
'bot_has_permissions',
'bot_has_any_role',
'cooldown',
'dynamic_cooldown',
'max_concurrency',
'dm_only',
'guild_only',
@ -256,7 +257,10 @@ class Command(_BaseCommand):
except AttributeError:
cooldown = kwargs.get('cooldown')
finally:
self._buckets = CooldownMapping(cooldown)
if cooldown is None:
self._buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
self._buckets = cooldown
try:
max_concurrency = func.__commands_max_concurrency__
@ -799,9 +803,10 @@ class Command(_BaseCommand):
dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
bucket = self._buckets.get_bucket(ctx.message, current)
retry_after = bucket.update_rate_limit(current)
if retry_after:
raise CommandOnCooldown(bucket, retry_after)
if bucket is not None:
retry_after = bucket.update_rate_limit(current)
if retry_after:
raise CommandOnCooldown(bucket, retry_after)
async def prepare(self, ctx):
ctx.command = self
@ -2014,9 +2019,48 @@ def cooldown(rate, per, type=BucketType.default):
def decorator(func):
if isinstance(func, Command):
func._buckets = CooldownMapping(Cooldown(rate, per, type))
func._buckets = CooldownMapping(Cooldown(rate, per), type)
else:
func.__commands_cooldown__ = Cooldown(rate, per, type)
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
return func
return decorator
def dynamic_cooldown(cooldown, type=BucketType.default):
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
This differs from :func:`.cooldown` in that it takes a function that
accepts a single parameter of type :class:`.discord.Message` and must
return a :class:`.Cooldown`
A cooldown allows a command to only be used a specific amount
of times in a specific time frame. These cooldowns can be based
either on a per-guild, per-channel, per-user, per-role or global basis.
Denoted by the third argument of ``type`` which must be of enum
type :class:`.BucketType`.
If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in
:func:`.on_command_error` and the local error handler.
A command can only have a single cooldown.
.. versionadded:: 2.0
Parameters
------------
cooldown: Callable[[:class:`.discord.Message`], :class:`.Cooldown`]
A function that takes a message and returns a cooldown that will
apply to this invocation
type: :class:`.BucketType`
The type of cooldown to have.
"""
if not callable(cooldown):
raise TypeError("A callable must be provided")
def decorator(func):
if isinstance(func, Command):
func._buckets = DynamicCooldownMapping(cooldown, type)
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
return func
return decorator

Loading…
Cancel
Save