|
|
@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING |
|
|
|
from discord.enums import Enum |
|
|
|
import time |
|
|
|
import asyncio |
|
|
@ -30,6 +34,9 @@ from collections import deque |
|
|
|
from ...abc import PrivateChannel |
|
|
|
from .errors import MaxConcurrencyReached |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from ...message import Message |
|
|
|
|
|
|
|
__all__ = ( |
|
|
|
'BucketType', |
|
|
|
'Cooldown', |
|
|
@ -38,6 +45,9 @@ __all__ = ( |
|
|
|
'MaxConcurrency', |
|
|
|
) |
|
|
|
|
|
|
|
C = TypeVar('C', bound='CooldownMapping') |
|
|
|
MC = TypeVar('MC', bound='MaxConcurrency') |
|
|
|
|
|
|
|
class BucketType(Enum): |
|
|
|
default = 0 |
|
|
|
user = 1 |
|
|
@ -47,7 +57,7 @@ class BucketType(Enum): |
|
|
|
category = 5 |
|
|
|
role = 6 |
|
|
|
|
|
|
|
def get_key(self, msg): |
|
|
|
def get_key(self, msg: Message) -> Any: |
|
|
|
if self is BucketType.user: |
|
|
|
return msg.author.id |
|
|
|
elif self is BucketType.guild: |
|
|
@ -57,29 +67,52 @@ class BucketType(Enum): |
|
|
|
elif self is BucketType.member: |
|
|
|
return ((msg.guild and msg.guild.id), msg.author.id) |
|
|
|
elif self is BucketType.category: |
|
|
|
return (msg.channel.category or msg.channel).id |
|
|
|
return (msg.channel.category or msg.channel).id # type: ignore |
|
|
|
elif self is BucketType.role: |
|
|
|
# we return the channel id of a private-channel as there are only roles in guilds |
|
|
|
# and that yields the same result as for a guild with only the @everyone role |
|
|
|
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are |
|
|
|
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do |
|
|
|
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id |
|
|
|
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore |
|
|
|
|
|
|
|
def __call__(self, msg): |
|
|
|
def __call__(self, msg: Message) -> Any: |
|
|
|
return self.get_key(msg) |
|
|
|
|
|
|
|
|
|
|
|
class Cooldown: |
|
|
|
__slots__ = ('rate', 'per', '_window', '_tokens', '_last') |
|
|
|
"""Represents a cooldown for a command. |
|
|
|
|
|
|
|
Attributes |
|
|
|
----------- |
|
|
|
rate: :class:`int` |
|
|
|
The total number of tokens available per :attr:`per` seconds. |
|
|
|
per: :class:`float` |
|
|
|
The length of the cooldown period in seconds. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, rate, per): |
|
|
|
self.rate = int(rate) |
|
|
|
self.per = float(per) |
|
|
|
self._window = 0.0 |
|
|
|
self._tokens = self.rate |
|
|
|
self._last = 0.0 |
|
|
|
__slots__ = ('rate', 'per', '_window', '_tokens', '_last') |
|
|
|
|
|
|
|
def get_tokens(self, current=None): |
|
|
|
def __init__(self, rate: float, per: float) -> None: |
|
|
|
self.rate: int = int(rate) |
|
|
|
self.per: float = float(per) |
|
|
|
self._window: float = 0.0 |
|
|
|
self._tokens: int = self.rate |
|
|
|
self._last: float = 0.0 |
|
|
|
|
|
|
|
def get_tokens(self, current: Optional[float] = None) -> int: |
|
|
|
"""Returns the number of available tokens before rate limiting is applied. |
|
|
|
|
|
|
|
Parameters |
|
|
|
------------ |
|
|
|
current: Optional[:class:`float`] |
|
|
|
The time in seconds since Unix epoch to calculate tokens at. |
|
|
|
If not supplied then :func:`time.time()` is used. |
|
|
|
|
|
|
|
Returns |
|
|
|
-------- |
|
|
|
:class:`int` |
|
|
|
The number of tokens available before the cooldown is to be applied. |
|
|
|
""" |
|
|
|
if not current: |
|
|
|
current = time.time() |
|
|
|
|
|
|
@ -89,7 +122,20 @@ class Cooldown: |
|
|
|
tokens = self.rate |
|
|
|
return tokens |
|
|
|
|
|
|
|
def get_retry_after(self, current=None): |
|
|
|
def get_retry_after(self, current: Optional[float] = None) -> float: |
|
|
|
"""Returns the time in seconds until the cooldown will be reset. |
|
|
|
|
|
|
|
Parameters |
|
|
|
------------- |
|
|
|
current: Optional[:class:`float`] |
|
|
|
The current time in seconds since Unix epoch. |
|
|
|
If not supplied, then :func:`time.time()` is used. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
:class:`float` |
|
|
|
The number of seconds to wait before this cooldown will be reset. |
|
|
|
""" |
|
|
|
current = current or time.time() |
|
|
|
tokens = self.get_tokens(current) |
|
|
|
|
|
|
@ -98,7 +144,20 @@ class Cooldown: |
|
|
|
|
|
|
|
return 0.0 |
|
|
|
|
|
|
|
def update_rate_limit(self, current=None): |
|
|
|
def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: |
|
|
|
"""Updates the cooldown rate limit. |
|
|
|
|
|
|
|
Parameters |
|
|
|
------------- |
|
|
|
current: Optional[:class:`float`] |
|
|
|
The time in seconds since Unix epoch to update the rate limit at. |
|
|
|
If not supplied, then :func:`time.time()` is used. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
Optional[:class:`float`] |
|
|
|
The retry-after time in seconds if rate limited. |
|
|
|
""" |
|
|
|
current = current or time.time() |
|
|
|
self._last = current |
|
|
|
|
|
|
@ -115,46 +174,58 @@ class Cooldown: |
|
|
|
# we're not so decrement our tokens |
|
|
|
self._tokens -= 1 |
|
|
|
|
|
|
|
def reset(self): |
|
|
|
def reset(self) -> None: |
|
|
|
"""Reset the cooldown to its initial state.""" |
|
|
|
self._tokens = self.rate |
|
|
|
self._last = 0.0 |
|
|
|
|
|
|
|
def copy(self): |
|
|
|
def copy(self) -> Cooldown: |
|
|
|
"""Creates a copy of this cooldown. |
|
|
|
|
|
|
|
Returns |
|
|
|
-------- |
|
|
|
:class:`Cooldown` |
|
|
|
A new instance of this cooldown. |
|
|
|
""" |
|
|
|
return Cooldown(self.rate, self.per) |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
def __repr__(self) -> str: |
|
|
|
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>' |
|
|
|
|
|
|
|
class CooldownMapping: |
|
|
|
def __init__(self, original, type): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
original: Optional[Cooldown], |
|
|
|
type: Callable[[Message], Any], |
|
|
|
) -> None: |
|
|
|
if not callable(type): |
|
|
|
raise TypeError('Cooldown type must be a BucketType or callable') |
|
|
|
|
|
|
|
self._cache = {} |
|
|
|
self._cooldown = original |
|
|
|
self._type = type |
|
|
|
self._cache: Dict[Any, Cooldown] = {} |
|
|
|
self._cooldown: Optional[Cooldown] = original |
|
|
|
self._type: Callable[[Message], Any] = type |
|
|
|
|
|
|
|
def copy(self): |
|
|
|
def copy(self) -> CooldownMapping: |
|
|
|
ret = CooldownMapping(self._cooldown, self._type) |
|
|
|
ret._cache = self._cache.copy() |
|
|
|
return ret |
|
|
|
|
|
|
|
@property |
|
|
|
def valid(self): |
|
|
|
def valid(self) -> bool: |
|
|
|
return self._cooldown is not None |
|
|
|
|
|
|
|
@property |
|
|
|
def type(self): |
|
|
|
def type(self) -> Callable[[Message], Any]: |
|
|
|
return self._type |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def from_cooldown(cls, rate, per, type): |
|
|
|
def from_cooldown(cls: Type[C], rate, per, type) -> C: |
|
|
|
return cls(Cooldown(rate, per), type) |
|
|
|
|
|
|
|
def _bucket_key(self, msg): |
|
|
|
def _bucket_key(self, msg: Message) -> Any: |
|
|
|
return self._type(msg) |
|
|
|
|
|
|
|
def _verify_cache_integrity(self, current=None): |
|
|
|
def _verify_cache_integrity(self, current: Optional[float] = None) -> None: |
|
|
|
# we want to delete all cache objects that haven't been used |
|
|
|
# in a cooldown window. e.g. if we have a command that has a |
|
|
|
# cooldown of 60s and it has not been used in 60s then that key should be deleted |
|
|
@ -163,12 +234,12 @@ class CooldownMapping: |
|
|
|
for k in dead_keys: |
|
|
|
del self._cache[k] |
|
|
|
|
|
|
|
def create_bucket(self, message): |
|
|
|
return self._cooldown.copy() |
|
|
|
def create_bucket(self, message: Message) -> Cooldown: |
|
|
|
return self._cooldown.copy() # type: ignore |
|
|
|
|
|
|
|
def get_bucket(self, message, current=None): |
|
|
|
def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown: |
|
|
|
if self._type is BucketType.default: |
|
|
|
return self._cooldown |
|
|
|
return self._cooldown # type: ignore |
|
|
|
|
|
|
|
self._verify_cache_integrity(current) |
|
|
|
key = self._bucket_key(message) |
|
|
@ -181,26 +252,30 @@ class CooldownMapping: |
|
|
|
|
|
|
|
return bucket |
|
|
|
|
|
|
|
def update_rate_limit(self, message, current=None): |
|
|
|
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: |
|
|
|
bucket = self.get_bucket(message, current) |
|
|
|
return bucket.update_rate_limit(current) |
|
|
|
|
|
|
|
class DynamicCooldownMapping(CooldownMapping): |
|
|
|
|
|
|
|
def __init__(self, factory, type): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
factory: Callable[[Message], Cooldown], |
|
|
|
type: Callable[[Message], Any] |
|
|
|
) -> None: |
|
|
|
super().__init__(None, type) |
|
|
|
self._factory = factory |
|
|
|
self._factory: Callable[[Message], Cooldown] = factory |
|
|
|
|
|
|
|
def copy(self): |
|
|
|
def copy(self) -> DynamicCooldownMapping: |
|
|
|
ret = DynamicCooldownMapping(self._factory, self._type) |
|
|
|
ret._cache = self._cache.copy() |
|
|
|
return ret |
|
|
|
|
|
|
|
@property |
|
|
|
def valid(self): |
|
|
|
def valid(self) -> bool: |
|
|
|
return True |
|
|
|
|
|
|
|
def create_bucket(self, message): |
|
|
|
def create_bucket(self, message: Message) -> Cooldown: |
|
|
|
return self._factory(message) |
|
|
|
|
|
|
|
class _Semaphore: |
|
|
@ -218,28 +293,28 @@ class _Semaphore: |
|
|
|
|
|
|
|
__slots__ = ('value', 'loop', '_waiters') |
|
|
|
|
|
|
|
def __init__(self, number): |
|
|
|
self.value = number |
|
|
|
self.loop = asyncio.get_event_loop() |
|
|
|
self._waiters = deque() |
|
|
|
def __init__(self, number: int) -> None: |
|
|
|
self.value: int = number |
|
|
|
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() |
|
|
|
self._waiters: Deque[asyncio.Future] = deque() |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
def __repr__(self) -> str: |
|
|
|
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' |
|
|
|
|
|
|
|
def locked(self): |
|
|
|
def locked(self) -> bool: |
|
|
|
return self.value == 0 |
|
|
|
|
|
|
|
def is_active(self): |
|
|
|
def is_active(self) -> bool: |
|
|
|
return len(self._waiters) > 0 |
|
|
|
|
|
|
|
def wake_up(self): |
|
|
|
def wake_up(self) -> None: |
|
|
|
while self._waiters: |
|
|
|
future = self._waiters.popleft() |
|
|
|
if not future.done(): |
|
|
|
future.set_result(None) |
|
|
|
return |
|
|
|
|
|
|
|
async def acquire(self, *, wait=False): |
|
|
|
async def acquire(self, *, wait: bool = False) -> bool: |
|
|
|
if not wait and self.value <= 0: |
|
|
|
# signal that we're not acquiring |
|
|
|
return False |
|
|
@ -258,18 +333,18 @@ class _Semaphore: |
|
|
|
self.value -= 1 |
|
|
|
return True |
|
|
|
|
|
|
|
def release(self): |
|
|
|
def release(self) -> None: |
|
|
|
self.value += 1 |
|
|
|
self.wake_up() |
|
|
|
|
|
|
|
class MaxConcurrency: |
|
|
|
__slots__ = ('number', 'per', 'wait', '_mapping') |
|
|
|
|
|
|
|
def __init__(self, number, *, per, wait): |
|
|
|
self._mapping = {} |
|
|
|
self.per = per |
|
|
|
self.number = number |
|
|
|
self.wait = wait |
|
|
|
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: |
|
|
|
self._mapping: Dict[Any, _Semaphore] = {} |
|
|
|
self.per: BucketType = per |
|
|
|
self.number: int = number |
|
|
|
self.wait: bool = wait |
|
|
|
|
|
|
|
if number <= 0: |
|
|
|
raise ValueError('max_concurrency \'number\' cannot be less than 1') |
|
|
@ -277,16 +352,16 @@ class MaxConcurrency: |
|
|
|
if not isinstance(per, BucketType): |
|
|
|
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}') |
|
|
|
|
|
|
|
def copy(self): |
|
|
|
def copy(self: MC) -> MC: |
|
|
|
return self.__class__(self.number, per=self.per, wait=self.wait) |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
def __repr__(self) -> str: |
|
|
|
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>' |
|
|
|
|
|
|
|
def get_key(self, message): |
|
|
|
def get_key(self, message: Message) -> Any: |
|
|
|
return self.per.get_key(message) |
|
|
|
|
|
|
|
async def acquire(self, message): |
|
|
|
async def acquire(self, message: Message) -> None: |
|
|
|
key = self.get_key(message) |
|
|
|
|
|
|
|
try: |
|
|
@ -298,7 +373,7 @@ class MaxConcurrency: |
|
|
|
if not acquired: |
|
|
|
raise MaxConcurrencyReached(self.number, self.per) |
|
|
|
|
|
|
|
async def release(self, message): |
|
|
|
async def release(self, message: Message) -> None: |
|
|
|
# Technically there's no reason for this function to be async |
|
|
|
# But it might be more useful in the future |
|
|
|
key = self.get_key(message) |
|
|
|