diff --git a/discord/ext/commands/__init__.py b/discord/ext/commands/__init__.py index ed28b6a06..d3b64a29e 100644 --- a/discord/ext/commands/__init__.py +++ b/discord/ext/commands/__init__.py @@ -16,3 +16,4 @@ from .core import * from .errors import * from .formatter import HelpFormatter, Paginator from .converter import * +from .cooldowns import BucketType diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py new file mode 100644 index 000000000..035ac809e --- /dev/null +++ b/discord/ext/commands/cooldowns.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import enum +import time + +__all__ = ['BucketType', 'Cooldown', 'CooldownMapping'] + +class BucketType(enum.Enum): + default = 0 + user = 1 + server = 2 + channel = 3 + +class Cooldown: + __slots__ = ['rate', 'per', 'type', '_window', '_tokens', '_last'] + + def __init__(self, rate, per, type): + self.rate = int(rate) + self.per = float(per) + self.type = type + self._window = 0.0 + self._tokens = self.rate + + if not isinstance(self.type, BucketType): + raise TypeError('Cooldown type must be a BucketType') + + def is_rate_limited(self): + current = time.time() + self._last = current + + # first token used means that we start a new rate limit window + if self._tokens == self.rate: + self._window = current + + # check if our window has passed and we can refresh our tokens + if current > self._window + self.per: + self._tokens = self.rate + self._window = current + + # check if we're rate limited + if self._tokens == 0: + return self.per - (current - self._window) + + # we're not so decrement our tokens + self._tokens -= 1 + + # see if we got rate limited due to this token change, and if + # so update the window to point to our current time frame + if self._tokens == 0: + self._window = current + + def copy(self): + return Cooldown(self.rate, self.per, self.type) + + def __repr__(self): + return ''.format(self) + +class CooldownMapping: + def __init__(self, original): + self._cache = {} + self._cooldown = original + + @property + def valid(self): + return self._cooldown is not None + + def _bucket_key(self, ctx): + msg = ctx.message + bucket_type = self._cooldown.type + if bucket_type is BucketType.user: + return msg.author.id + elif bucket_type is BucketType.server: + return getattr(msg.server, 'id', msg.author.id) + elif bucket_type is BucketType.channel: + return msg.channel.id + + def _verify_cache_integrity(self): + # we want to delete all cache objects that haven't been used + # in a cooldown window. e.g. if we have a command that has a + # cooldown of 60s and it has not been used in 60s then that key should be deleted + current = time.time() + dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per] + for k in dead_keys: + del self._cache[k] + + def get_bucket(self, ctx): + if self._cooldown.type is BucketType.default: + return self._cooldown + + self._verify_cache_integrity() + key = self._bucket_key(ctx) + if key not in self._cache: + bucket = self._cooldown.copy() + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index fff90ae6d..8303181b8 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -30,12 +30,14 @@ import discord import functools from .errors import * +from .cooldowns import Cooldown, BucketType, CooldownMapping from .view import quoted_word from . import converter as converters __all__ = [ 'Command', 'Group', 'GroupMixin', 'command', 'group', 'has_role', 'has_permissions', 'has_any_role', 'check', - 'bot_has_role', 'bot_has_permissions', 'bot_has_any_role' ] + 'bot_has_role', 'bot_has_permissions', 'bot_has_any_role', + 'cooldown' ] def inject_context(ctx, coro): @functools.wraps(coro) @@ -142,6 +144,7 @@ class Command: self.ignore_extra = kwargs.get('ignore_extra', True) self.instance = None self.parent = None + self._buckets = CooldownMapping(kwargs.get('cooldown')) def dispatch_error(self, error, ctx): try: @@ -328,6 +331,12 @@ class Command: if not self.can_run(ctx): raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self)) + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) + retry_after = bucket.is_rate_limited() + if retry_after: + raise CommandOnCooldown(bucket, retry_after) + @asyncio.coroutine def invoke(self, ctx): ctx.command = self @@ -637,6 +646,12 @@ def command(name=None, cls=None, **attrs): except AttributeError: checks = [] + try: + cooldown = func.__commands_cooldown__ + del func.__commands_cooldown__ + except AttributeError: + cooldown = None + help_doc = attrs.get('help') if help_doc is not None: help_doc = inspect.cleandoc(help_doc) @@ -647,7 +662,7 @@ def command(name=None, cls=None, **attrs): attrs['help'] = help_doc fname = name or func.__name__ - return cls(name=fname, callback=func, checks=checks, **attrs) + return cls(name=fname, callback=func, checks=checks, cooldown=cooldown, **attrs) return decorator @@ -848,3 +863,41 @@ def bot_has_permissions(**perms): permissions = ch.permissions_for(me) return all(getattr(permissions, perm, None) == value for perm, value in perms.items()) return check(predicate) + +def cooldown(rate, per, type=BucketType.default): + """A decorator that adds a cooldown to a :class:`Command` + or its subclasses. + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns can be based + either on a per-server, per-channel, per-user, or global basis. + Denoted by the third argument of ``type`` which must be of enum + type ``BucketType`` which could be either: + + - ``BucketType.default`` for a global basis. + - ``BucketType.user`` for a per-user basis. + - ``BucketType.server`` for a per-server basis. + - ``BucketType.channel`` for a per-channel basis. + + If a cooldown is triggered, then :exc:`CommandOnCooldown` is triggered in + :func:`on_command_error` and the local error handler. + + A command can only have a single cooldown. + + Parameters + ------------ + rate: int + The number of times a command can be used before triggering a cooldown. + per: float + The amount of seconds to wait for a cooldown when it's been triggered. + type: ``BucketType`` + The type of cooldown to have. + """ + + def decorator(func): + if isinstance(func, Command): + func.cooldown = Cooldown(rate, per, type) + else: + func.__commands_cooldown__ = Cooldown(rate, per, type) + return func + return decorator diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 1cdb0652f..b42b87e92 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -29,7 +29,7 @@ from discord.errors import DiscordException __all__ = [ 'CommandError', 'MissingRequiredArgument', 'BadArgument', 'NoPrivateMessage', 'CheckFailure', 'CommandNotFound', 'DisabledCommand', 'CommandInvokeError', 'TooManyArguments', - 'UserInputError' ] + 'UserInputError', 'CommandOnCooldown' ] class CommandError(DiscordException): """The base exception type for all command related errors. @@ -110,3 +110,18 @@ class CommandInvokeError(CommandError): self.original = e super().__init__('Command raised an exception: {0.__class__.__name__}: {0}'.format(e)) +class CommandOnCooldown(CommandError): + """Exception raised when the command being invoked is on cooldown. + + Attributes + ----------- + cooldown: Cooldown + A class with attributes ``rate``, ``per``, and ``type`` similar to + the :func:`cooldown` decorator. + retry_after: float + The amount of seconds to wait before you can retry again. + """ + def __init__(self, cooldown, retry_after): + self.cooldown = cooldown + self.retry_after = retry_after + super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after))