|
|
@ -26,13 +26,17 @@ DEALINGS IN THE SOFTWARE. |
|
|
|
|
|
|
|
from discord.enums import Enum |
|
|
|
import time |
|
|
|
import asyncio |
|
|
|
from collections import deque |
|
|
|
|
|
|
|
from ...abc import PrivateChannel |
|
|
|
from .errors import MaxConcurrencyReached |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'BucketType', |
|
|
|
'Cooldown', |
|
|
|
'CooldownMapping', |
|
|
|
'MaxConcurrency', |
|
|
|
) |
|
|
|
|
|
|
|
class BucketType(Enum): |
|
|
@ -163,3 +167,129 @@ class CooldownMapping: |
|
|
|
def update_rate_limit(self, message, current=None): |
|
|
|
bucket = self.get_bucket(message, current) |
|
|
|
return bucket.update_rate_limit(current) |
|
|
|
|
|
|
|
class _Semaphore: |
|
|
|
"""This class is a version of a semaphore. |
|
|
|
|
|
|
|
If you're wondering why asyncio.Semaphore isn't being used, |
|
|
|
it's because it doesn't expose the internal value. This internal |
|
|
|
value is necessary because I need to support both `wait=True` and |
|
|
|
`wait=False`. |
|
|
|
|
|
|
|
An asyncio.Queue could have been used to do this as well -- but it |
|
|
|
not as inefficient since internally that uses two queues and is a bit |
|
|
|
overkill for what is basically a counter. |
|
|
|
""" |
|
|
|
|
|
|
|
__slots__ = ('value', 'loop', '_waiters') |
|
|
|
|
|
|
|
def __init__(self, number): |
|
|
|
self.value = number |
|
|
|
self.loop = asyncio.get_event_loop() |
|
|
|
self._waiters = deque() |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters)) |
|
|
|
|
|
|
|
def locked(self): |
|
|
|
return self.value == 0 |
|
|
|
|
|
|
|
def wake_up(self): |
|
|
|
while self._waiters: |
|
|
|
future = self._waiters.popleft() |
|
|
|
if not future.done(): |
|
|
|
future.set_result(None) |
|
|
|
return |
|
|
|
|
|
|
|
async def acquire(self, *, wait=False): |
|
|
|
if not wait and self.value <= 0: |
|
|
|
# signal that we're not acquiring |
|
|
|
return False |
|
|
|
|
|
|
|
while self.value <= 0: |
|
|
|
future = self.loop.create_future() |
|
|
|
self._waiters.append(future) |
|
|
|
try: |
|
|
|
await future |
|
|
|
except: |
|
|
|
future.cancel() |
|
|
|
if self.value > 0 and not future.cancelled(): |
|
|
|
self.wake_up() |
|
|
|
raise |
|
|
|
|
|
|
|
self.value -= 1 |
|
|
|
return True |
|
|
|
|
|
|
|
def release(self): |
|
|
|
self.value += 1 |
|
|
|
self.wake_up() |
|
|
|
|
|
|
|
class MaxConcurrency: |
|
|
|
__slots__ = ('number', 'per', 'wait', '_mapping') |
|
|
|
|
|
|
|
def __init__(self, number, *, per, wait): |
|
|
|
self._mapping = {} |
|
|
|
self.per = per |
|
|
|
self.number = number |
|
|
|
self.wait = wait |
|
|
|
|
|
|
|
if number <= 0: |
|
|
|
raise ValueError('max_concurrency \'number\' cannot be less than 1') |
|
|
|
|
|
|
|
if not isinstance(per, BucketType): |
|
|
|
raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per)) |
|
|
|
|
|
|
|
def copy(self): |
|
|
|
return self.__class__(self.number, per=self.per, wait=self.wait) |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self) |
|
|
|
|
|
|
|
def get_bucket(self, message): |
|
|
|
bucket_type = self.per |
|
|
|
if bucket_type is BucketType.default: |
|
|
|
return 'global' |
|
|
|
elif bucket_type is BucketType.user: |
|
|
|
return message.author.id |
|
|
|
elif bucket_type is BucketType.guild: |
|
|
|
return (message.guild or message.author).id |
|
|
|
elif bucket_type is BucketType.channel: |
|
|
|
return message.channel.id |
|
|
|
elif bucket_type is BucketType.member: |
|
|
|
return ((message.guild and message.guild.id), message.author.id) |
|
|
|
elif bucket_type is BucketType.category: |
|
|
|
return (message.channel.category or message.channel).id |
|
|
|
elif bucket_type is BucketType.role: |
|
|
|
# we return the channel id of a private-channel as there are only roles in guilds |
|
|
|
# and that yields the same result as for a guild with only the @everyone role |
|
|
|
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are |
|
|
|
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do |
|
|
|
return (message.channel if isinstance(message.channel, PrivateChannel) else message.author.top_role).id |
|
|
|
|
|
|
|
async def acquire(self, message): |
|
|
|
key = self.get_bucket(message) |
|
|
|
|
|
|
|
try: |
|
|
|
sem = self._mapping[key] |
|
|
|
except KeyError: |
|
|
|
self._mapping[key] = sem = _Semaphore(self.number) |
|
|
|
|
|
|
|
acquired = await sem.acquire(wait=self.wait) |
|
|
|
if not acquired: |
|
|
|
raise MaxConcurrencyReached(self.number, self.per) |
|
|
|
|
|
|
|
async def release(self, message): |
|
|
|
# Technically there's no reason for this function to be async |
|
|
|
# But it might be more useful in the future |
|
|
|
key = self.get_bucket(message) |
|
|
|
|
|
|
|
try: |
|
|
|
sem = self._mapping[key] |
|
|
|
except KeyError: |
|
|
|
# ...? peculiar |
|
|
|
return |
|
|
|
else: |
|
|
|
sem.release() |
|
|
|
|
|
|
|
if sem.value >= self.number: |
|
|
|
del self._mapping[key] |
|
|
|