Browse Source

[commands] Add max_concurrency decorator

pull/2519/head
Rapptz 5 years ago
parent
commit
bf84c63396
  1. 130
      discord/ext/commands/cooldowns.py
  2. 52
      discord/ext/commands/core.py
  3. 23
      discord/ext/commands/errors.py

130
discord/ext/commands/cooldowns.py

@ -26,13 +26,17 @@ DEALINGS IN THE SOFTWARE.
from discord.enums import Enum from discord.enums import Enum
import time import time
import asyncio
from collections import deque
from ...abc import PrivateChannel from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached
__all__ = ( __all__ = (
'BucketType', 'BucketType',
'Cooldown', 'Cooldown',
'CooldownMapping', 'CooldownMapping',
'MaxConcurrency',
) )
class BucketType(Enum): class BucketType(Enum):
@ -163,3 +167,129 @@ class CooldownMapping:
def update_rate_limit(self, message, current=None): def update_rate_limit(self, message, current=None):
bucket = self.get_bucket(message, current) bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current) return bucket.update_rate_limit(current)
class _Semaphore:
"""This class is a version of a semaphore.
If you're wondering why asyncio.Semaphore isn't being used,
it's because it doesn't expose the internal value. This internal
value is necessary because I need to support both `wait=True` and
`wait=False`.
An asyncio.Queue could have been used to do this as well -- but it
not as inefficient since internally that uses two queues and is a bit
overkill for what is basically a counter.
"""
__slots__ = ('value', 'loop', '_waiters')
def __init__(self, number):
self.value = number
self.loop = asyncio.get_event_loop()
self._waiters = deque()
def __repr__(self):
return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters))
def locked(self):
return self.value == 0
def wake_up(self):
while self._waiters:
future = self._waiters.popleft()
if not future.done():
future.set_result(None)
return
async def acquire(self, *, wait=False):
if not wait and self.value <= 0:
# signal that we're not acquiring
return False
while self.value <= 0:
future = self.loop.create_future()
self._waiters.append(future)
try:
await future
except:
future.cancel()
if self.value > 0 and not future.cancelled():
self.wake_up()
raise
self.value -= 1
return True
def release(self):
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping')
def __init__(self, number, *, per, wait):
self._mapping = {}
self.per = per
self.number = number
self.wait = wait
if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1')
if not isinstance(per, BucketType):
raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per))
def copy(self):
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self):
return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
def get_bucket(self, message):
bucket_type = self.per
if bucket_type is BucketType.default:
return 'global'
elif bucket_type is BucketType.user:
return message.author.id
elif bucket_type is BucketType.guild:
return (message.guild or message.author).id
elif bucket_type is BucketType.channel:
return message.channel.id
elif bucket_type is BucketType.member:
return ((message.guild and message.guild.id), message.author.id)
elif bucket_type is BucketType.category:
return (message.channel.category or message.channel).id
elif bucket_type is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id
async def acquire(self, message):
key = self.get_bucket(message)
try:
sem = self._mapping[key]
except KeyError:
self._mapping[key] = sem = _Semaphore(self.number)
acquired = await sem.acquire(wait=self.wait)
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message):
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_bucket(message)
try:
sem = self._mapping[key]
except KeyError:
# ...? peculiar
return
else:
sem.release()
if sem.value >= self.number:
del self._mapping[key]

52
discord/ext/commands/core.py

@ -33,7 +33,7 @@ import datetime
import discord import discord
from .errors import * from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency
from . import converter as converters from . import converter as converters
from ._types import _BaseCommand from ._types import _BaseCommand
from .cog import Cog from .cog import Cog
@ -53,6 +53,7 @@ __all__ = (
'bot_has_permissions', 'bot_has_permissions',
'bot_has_any_role', 'bot_has_any_role',
'cooldown', 'cooldown',
'max_concurrency',
'dm_only', 'dm_only',
'guild_only', 'guild_only',
'is_owner', 'is_owner',
@ -90,6 +91,9 @@ def hooked_wrapped_callback(command, ctx, coro):
ctx.command_failed = True ctx.command_failed = True
raise CommandInvokeError(exc) from exc raise CommandInvokeError(exc) from exc
finally: finally:
if command._max_concurrency is not None:
await command._max_concurrency.release(ctx)
await command.call_after_hooks(ctx) await command.call_after_hooks(ctx)
return ret return ret
return wrapped return wrapped
@ -248,6 +252,13 @@ class Command(_BaseCommand):
finally: finally:
self._buckets = CooldownMapping(cooldown) self._buckets = CooldownMapping(cooldown)
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get('max_concurrency')
finally:
self._max_concurrency = max_concurrency
self.ignore_extra = kwargs.get('ignore_extra', True) self.ignore_extra = kwargs.get('ignore_extra', True)
self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False) self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False)
self.cog = None self.cog = None
@ -331,6 +342,9 @@ class Command(_BaseCommand):
other.checks = self.checks.copy() other.checks = self.checks.copy()
if self._buckets.valid and not other._buckets.valid: if self._buckets.valid and not other._buckets.valid:
other._buckets = self._buckets.copy() other._buckets = self._buckets.copy()
if self._max_concurrency != other._max_concurrency:
other._max_concurrency = self._max_concurrency.copy()
try: try:
other.on_error = self.on_error other.on_error = self.on_error
except AttributeError: except AttributeError:
@ -718,6 +732,9 @@ class Command(_BaseCommand):
self._prepare_cooldowns(ctx) self._prepare_cooldowns(ctx)
await self._parse_arguments(ctx) await self._parse_arguments(ctx)
if self._max_concurrency is not None:
await self._max_concurrency.acquire(ctx)
await self.call_before_hooks(ctx) await self.call_before_hooks(ctx)
def is_on_cooldown(self, ctx): def is_on_cooldown(self, ctx):
@ -1800,3 +1817,36 @@ def cooldown(rate, per, type=BucketType.default):
func.__commands_cooldown__ = Cooldown(rate, per, type) func.__commands_cooldown__ = Cooldown(rate, per, type)
return func return func
return decorator return decorator
def max_concurrency(number, per=BucketType.default, *, wait=False):
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
This enables you to only allow a certain number of command invocations at the same time,
for example if a command takes too long or if only one user can use it at a time. This
differs from a cooldown in that there is no set waiting period or token bucket -- only
a set number of people can run the command.
.. versionadded:: 1.3.0
Parameters
-------------
number: :class:`int`
The maximum number of invocations of this command that can be running at the same time.
per: :class:`.BucketType`
The bucket that this concurrency is based on, e.g. ``BucketType.guild`` would allow
it to be used up to ``number`` times per guild.
wait: :class:`bool`
Whether the command should wait for the queue to be over. If this is set to ``False``
then instead of waiting until the command can run again, the command raises
:exc:`.MaxConcurrencyReached` to its error handler. If this is set to ``True``
then the command waits until it can be executed.
"""
def decorator(func):
value = MaxConcurrency(number, per=per, wait=wait)
if isinstance(func, Command):
func._max_concurrency = value
else:
func.__commands_max_concurrency__ = value
return func
return decorator

23
discord/ext/commands/errors.py

@ -41,6 +41,7 @@ __all__ = (
'TooManyArguments', 'TooManyArguments',
'UserInputError', 'UserInputError',
'CommandOnCooldown', 'CommandOnCooldown',
'MaxConcurrencyReached',
'NotOwner', 'NotOwner',
'MissingRole', 'MissingRole',
'BotMissingRole', 'BotMissingRole',
@ -240,6 +241,28 @@ class CommandOnCooldown(CommandError):
self.retry_after = retry_after self.retry_after = retry_after
super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after)) super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after))
class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency.
This inherits from :exc:`CommandError`.
Attributes
------------
number: :class:`int`
The maximum number of concurrent invokers allowed.
per: :class:`BucketType`
The bucket type passed to the :func:`.max_concurrency` decorator.
"""
def __init__(self, number, per):
self.number = number
self.per = per
name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s'
fmt = plural % (number, suffix)
super().__init__('Too many people using this command. It can only be used {}.'.format(fmt))
class MissingRole(CheckFailure): class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command. """Exception raised when the command invoker lacks a role to run a command.

Loading…
Cancel
Save