Browse Source

[commands] Provide a dynamic cooldown system

pull/6692/head
Dan Hess 4 years ago
committed by GitHub
parent
commit
f2d5ab6f80
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 42
      discord/ext/commands/cooldowns.py
  2. 58
      discord/ext/commands/core.py

42
discord/ext/commands/cooldowns.py

@ -34,6 +34,7 @@ __all__ = (
'BucketType', 'BucketType',
'Cooldown', 'Cooldown',
'CooldownMapping', 'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency', 'MaxConcurrency',
) )
@ -69,19 +70,15 @@ class BucketType(Enum):
class Cooldown: class Cooldown:
__slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last') __slots__ = ('rate', 'per', '_window', '_tokens', '_last')
def __init__(self, rate, per, type): def __init__(self, rate, per):
self.rate = int(rate) self.rate = int(rate)
self.per = float(per) self.per = float(per)
self.type = type
self._window = 0.0 self._window = 0.0
self._tokens = self.rate self._tokens = self.rate
self._last = 0.0 self._last = 0.0
if not callable(self.type):
raise TypeError('Cooldown type must be a BucketType or callable')
def get_tokens(self, current=None): def get_tokens(self, current=None):
if not current: if not current:
current = time.time() current = time.time()
@ -128,15 +125,19 @@ class Cooldown:
self._last = 0.0 self._last = 0.0
def copy(self): def copy(self):
return Cooldown(self.rate, self.per, self.type) return Cooldown(self.rate, self.per)
def __repr__(self): def __repr__(self):
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>' return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping: class CooldownMapping:
def __init__(self, original): def __init__(self, original, type):
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
self._cache = {} self._cache = {}
self._cooldown = original self._cooldown = original
self._type = type
def copy(self): def copy(self):
ret = CooldownMapping(self._cooldown) ret = CooldownMapping(self._cooldown)
@ -152,7 +153,7 @@ class CooldownMapping:
return cls(Cooldown(rate, per, type)) return cls(Cooldown(rate, per, type))
def _bucket_key(self, msg): def _bucket_key(self, msg):
return self._cooldown.type(msg) return self._type(msg)
def _verify_cache_integrity(self, current=None): def _verify_cache_integrity(self, current=None):
# we want to delete all cache objects that haven't been used # we want to delete all cache objects that haven't been used
@ -163,15 +164,19 @@ class CooldownMapping:
for k in dead_keys: for k in dead_keys:
del self._cache[k] del self._cache[k]
def create_bucket(self, message):
return self._cooldown.copy()
def get_bucket(self, message, current=None): def get_bucket(self, message, current=None):
if self._cooldown.type is BucketType.default: if self._type is BucketType.default:
return self._cooldown return self._cooldown
self._verify_cache_integrity(current) self._verify_cache_integrity(current)
key = self._bucket_key(message) key = self._bucket_key(message)
if key not in self._cache: if key not in self._cache:
bucket = self._cooldown.copy() bucket = self.create_bucket(message)
self._cache[key] = bucket if bucket is not None:
self._cache[key] = bucket
else: else:
bucket = self._cache[key] bucket = self._cache[key]
@ -181,6 +186,19 @@ class CooldownMapping:
bucket = self.get_bucket(message, current) bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current) return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(self, factory, type):
super().__init__(None, type)
self._factory = factory
@property
def valid(self):
return True
def create_bucket(self, message):
return self._factory(message)
class _Semaphore: class _Semaphore:
"""This class is a version of a semaphore. """This class is a version of a semaphore.

58
discord/ext/commands/core.py

@ -32,7 +32,7 @@ import sys
import discord import discord
from .errors import * from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
from . import converter as converters from . import converter as converters
from ._types import _BaseCommand from ._types import _BaseCommand
from .cog import Cog from .cog import Cog
@ -54,6 +54,7 @@ __all__ = (
'bot_has_permissions', 'bot_has_permissions',
'bot_has_any_role', 'bot_has_any_role',
'cooldown', 'cooldown',
'dynamic_cooldown',
'max_concurrency', 'max_concurrency',
'dm_only', 'dm_only',
'guild_only', 'guild_only',
@ -256,7 +257,10 @@ class Command(_BaseCommand):
except AttributeError: except AttributeError:
cooldown = kwargs.get('cooldown') cooldown = kwargs.get('cooldown')
finally: finally:
self._buckets = CooldownMapping(cooldown) if cooldown is None:
self._buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
self._buckets = cooldown
try: try:
max_concurrency = func.__commands_max_concurrency__ max_concurrency = func.__commands_max_concurrency__
@ -799,9 +803,10 @@ class Command(_BaseCommand):
dt = ctx.message.edited_at or ctx.message.created_at dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
bucket = self._buckets.get_bucket(ctx.message, current) bucket = self._buckets.get_bucket(ctx.message, current)
retry_after = bucket.update_rate_limit(current) if bucket is not None:
if retry_after: retry_after = bucket.update_rate_limit(current)
raise CommandOnCooldown(bucket, retry_after) if retry_after:
raise CommandOnCooldown(bucket, retry_after)
async def prepare(self, ctx): async def prepare(self, ctx):
ctx.command = self ctx.command = self
@ -2014,9 +2019,48 @@ def cooldown(rate, per, type=BucketType.default):
def decorator(func): def decorator(func):
if isinstance(func, Command): if isinstance(func, Command):
func._buckets = CooldownMapping(Cooldown(rate, per, type)) func._buckets = CooldownMapping(Cooldown(rate, per), type)
else: else:
func.__commands_cooldown__ = Cooldown(rate, per, type) func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
return func
return decorator
def dynamic_cooldown(cooldown, type=BucketType.default):
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
This differs from :func:`.cooldown` in that it takes a function that
accepts a single parameter of type :class:`.discord.Message` and must
return a :class:`.Cooldown`
A cooldown allows a command to only be used a specific amount
of times in a specific time frame. These cooldowns can be based
either on a per-guild, per-channel, per-user, per-role or global basis.
Denoted by the third argument of ``type`` which must be of enum
type :class:`.BucketType`.
If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in
:func:`.on_command_error` and the local error handler.
A command can only have a single cooldown.
.. versionadded:: 2.0
Parameters
------------
cooldown: Callable[[:class:`.discord.Message`], :class:`.Cooldown`]
A function that takes a message and returns a cooldown that will
apply to this invocation
type: :class:`.BucketType`
The type of cooldown to have.
"""
if not callable(cooldown):
raise TypeError("A callable must be provided")
def decorator(func):
if isinstance(func, Command):
func._buckets = DynamicCooldownMapping(cooldown, type)
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
return func return func
return decorator return decorator

Loading…
Cancel
Save