diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 5e7f2aa39..fe763fb62 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -26,13 +26,17 @@ DEALINGS IN THE SOFTWARE. from discord.enums import Enum import time +import asyncio +from collections import deque from ...abc import PrivateChannel +from .errors import MaxConcurrencyReached __all__ = ( 'BucketType', 'Cooldown', 'CooldownMapping', + 'MaxConcurrency', ) class BucketType(Enum): @@ -163,3 +167,129 @@ class CooldownMapping: def update_rate_limit(self, message, current=None): bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) + +class _Semaphore: + """This class is a version of a semaphore. + + If you're wondering why asyncio.Semaphore isn't being used, + it's because it doesn't expose the internal value. This internal + value is necessary because I need to support both `wait=True` and + `wait=False`. + + An asyncio.Queue could have been used to do this as well -- but it + not as inefficient since internally that uses two queues and is a bit + overkill for what is basically a counter. + """ + + __slots__ = ('value', 'loop', '_waiters') + + def __init__(self, number): + self.value = number + self.loop = asyncio.get_event_loop() + self._waiters = deque() + + def __repr__(self): + return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters)) + + def locked(self): + return self.value == 0 + + def wake_up(self): + while self._waiters: + future = self._waiters.popleft() + if not future.done(): + future.set_result(None) + return + + async def acquire(self, *, wait=False): + if not wait and self.value <= 0: + # signal that we're not acquiring + return False + + while self.value <= 0: + future = self.loop.create_future() + self._waiters.append(future) + try: + await future + except: + future.cancel() + if self.value > 0 and not future.cancelled(): + self.wake_up() + raise + + self.value -= 1 + return True + + def release(self): + self.value += 1 + self.wake_up() + +class MaxConcurrency: + __slots__ = ('number', 'per', 'wait', '_mapping') + + def __init__(self, number, *, per, wait): + self._mapping = {} + self.per = per + self.number = number + self.wait = wait + + if number <= 0: + raise ValueError('max_concurrency \'number\' cannot be less than 1') + + if not isinstance(per, BucketType): + raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per)) + + def copy(self): + return self.__class__(self.number, per=self.per, wait=self.wait) + + def __repr__(self): + return ''.format(self) + + def get_bucket(self, message): + bucket_type = self.per + if bucket_type is BucketType.default: + return 'global' + elif bucket_type is BucketType.user: + return message.author.id + elif bucket_type is BucketType.guild: + return (message.guild or message.author).id + elif bucket_type is BucketType.channel: + return message.channel.id + elif bucket_type is BucketType.member: + return ((message.guild and message.guild.id), message.author.id) + elif bucket_type is BucketType.category: + return (message.channel.category or message.channel).id + elif bucket_type is BucketType.role: + # we return the channel id of a private-channel as there are only roles in guilds + # and that yields the same result as for a guild with only the @everyone role + # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are + # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do + return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id + + async def acquire(self, message): + key = self.get_bucket(message) + + try: + sem = self._mapping[key] + except KeyError: + self._mapping[key] = sem = _Semaphore(self.number) + + acquired = await sem.acquire(wait=self.wait) + if not acquired: + raise MaxConcurrencyReached(self.number, self.per) + + async def release(self, message): + # Technically there's no reason for this function to be async + # But it might be more useful in the future + key = self.get_bucket(message) + + try: + sem = self._mapping[key] + except KeyError: + # ...? peculiar + return + else: + sem.release() + + if sem.value >= self.number: + del self._mapping[key] diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index aa4343f5c..e587d3a72 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -33,7 +33,7 @@ import datetime import discord from .errors import * -from .cooldowns import Cooldown, BucketType, CooldownMapping +from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency from . import converter as converters from ._types import _BaseCommand from .cog import Cog @@ -53,6 +53,7 @@ __all__ = ( 'bot_has_permissions', 'bot_has_any_role', 'cooldown', + 'max_concurrency', 'dm_only', 'guild_only', 'is_owner', @@ -90,6 +91,9 @@ def hooked_wrapped_callback(command, ctx, coro): ctx.command_failed = True raise CommandInvokeError(exc) from exc finally: + if command._max_concurrency is not None: + await command._max_concurrency.release(ctx) + await command.call_after_hooks(ctx) return ret return wrapped @@ -248,6 +252,13 @@ class Command(_BaseCommand): finally: self._buckets = CooldownMapping(cooldown) + try: + max_concurrency = func.__commands_max_concurrency__ + except AttributeError: + max_concurrency = kwargs.get('max_concurrency') + finally: + self._max_concurrency = max_concurrency + self.ignore_extra = kwargs.get('ignore_extra', True) self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False) self.cog = None @@ -331,6 +342,9 @@ class Command(_BaseCommand): other.checks = self.checks.copy() if self._buckets.valid and not other._buckets.valid: other._buckets = self._buckets.copy() + if self._max_concurrency != other._max_concurrency: + other._max_concurrency = self._max_concurrency.copy() + try: other.on_error = self.on_error except AttributeError: @@ -718,6 +732,9 @@ class Command(_BaseCommand): self._prepare_cooldowns(ctx) await self._parse_arguments(ctx) + if self._max_concurrency is not None: + await self._max_concurrency.acquire(ctx) + await self.call_before_hooks(ctx) def is_on_cooldown(self, ctx): @@ -1800,3 +1817,36 @@ def cooldown(rate, per, type=BucketType.default): func.__commands_cooldown__ = Cooldown(rate, per, type) return func return decorator + +def max_concurrency(number, per=BucketType.default, *, wait=False): + """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. + + This enables you to only allow a certain number of command invocations at the same time, + for example if a command takes too long or if only one user can use it at a time. This + differs from a cooldown in that there is no set waiting period or token bucket -- only + a set number of people can run the command. + + .. versionadded:: 1.3.0 + + Parameters + ------------- + number: :class:`int` + The maximum number of invocations of this command that can be running at the same time. + per: :class:`.BucketType` + The bucket that this concurrency is based on, e.g. ``BucketType.guild`` would allow + it to be used up to ``number`` times per guild. + wait: :class:`bool` + Whether the command should wait for the queue to be over. If this is set to ``False`` + then instead of waiting until the command can run again, the command raises + :exc:`.MaxConcurrencyReached` to its error handler. If this is set to ``True`` + then the command waits until it can be executed. + """ + + def decorator(func): + value = MaxConcurrency(number, per=per, wait=wait) + if isinstance(func, Command): + func._max_concurrency = value + else: + func.__commands_max_concurrency__ = value + return func + return decorator diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 6087c1dfd..0d5e0d0f1 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -41,6 +41,7 @@ __all__ = ( 'TooManyArguments', 'UserInputError', 'CommandOnCooldown', + 'MaxConcurrencyReached', 'NotOwner', 'MissingRole', 'BotMissingRole', @@ -240,6 +241,28 @@ class CommandOnCooldown(CommandError): self.retry_after = retry_after super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after)) +class MaxConcurrencyReached(CommandError): + """Exception raised when the command being invoked has reached its maximum concurrency. + + This inherits from :exc:`CommandError`. + + Attributes + ------------ + number: :class:`int` + The maximum number of concurrent invokers allowed. + per: :class:`BucketType` + The bucket type passed to the :func:`.max_concurrency` decorator. + """ + + def __init__(self, number, per): + self.number = number + self.per = per + name = per.name + suffix = 'per %s' % name if per.name != 'default' else 'globally' + plural = '%s times %s' if number > 1 else '%s time %s' + fmt = plural % (number, suffix) + super().__init__('Too many people using this command. It can only be used {}.'.format(fmt)) + class MissingRole(CheckFailure): """Exception raised when the command invoker lacks a role to run a command.