From 1c63816cc057ef81c3f32cb724614d059f878198 Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 10 Aug 2021 22:35:15 +1000 Subject: [PATCH] [commands] Document / type-hint cooldown --- discord/client.py | 2 +- discord/ext/commands/cooldowns.py | 189 +++++++++++++++++++++--------- discord/ext/commands/errors.py | 2 +- docs/ext/commands/api.rst | 8 ++ 4 files changed, 142 insertions(+), 59 deletions(-) diff --git a/discord/client.py b/discord/client.py index a41e061c6..c629b3c99 100644 --- a/discord/client.py +++ b/discord/client.py @@ -289,7 +289,7 @@ class Client: @property def stickers(self) -> List[GuildSticker]: - """List[:class:`GuildSticker`]: The stickers that the connected client has. + """List[:class:`.GuildSticker`]: The stickers that the connected client has. .. versionadded:: 2.0 """ diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 6092909ba..2e008aed4 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + + +from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING from discord.enums import Enum import time import asyncio @@ -30,6 +34,9 @@ from collections import deque from ...abc import PrivateChannel from .errors import MaxConcurrencyReached +if TYPE_CHECKING: + from ...message import Message + __all__ = ( 'BucketType', 'Cooldown', @@ -38,6 +45,9 @@ __all__ = ( 'MaxConcurrency', ) +C = TypeVar('C', bound='CooldownMapping') +MC = TypeVar('MC', bound='MaxConcurrency') + class BucketType(Enum): default = 0 user = 1 @@ -47,7 +57,7 @@ class BucketType(Enum): category = 5 role = 6 - def get_key(self, msg): + def get_key(self, msg: Message) -> Any: if self is BucketType.user: return msg.author.id elif self is BucketType.guild: @@ -57,29 +67,52 @@ class BucketType(Enum): elif self is BucketType.member: return ((msg.guild and msg.guild.id), msg.author.id) elif self is BucketType.category: - return (msg.channel.category or msg.channel).id + return (msg.channel.category or msg.channel).id # type: ignore elif self is BucketType.role: # we return the channel id of a private-channel as there are only roles in guilds # and that yields the same result as for a guild with only the @everyone role # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do - return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id + return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore - def __call__(self, msg): + def __call__(self, msg: Message) -> Any: return self.get_key(msg) class Cooldown: - __slots__ = ('rate', 'per', '_window', '_tokens', '_last') + """Represents a cooldown for a command. + + Attributes + ----------- + rate: :class:`int` + The total number of tokens available per :attr:`per` seconds. + per: :class:`float` + The length of the cooldown period in seconds. + """ - def __init__(self, rate, per): - self.rate = int(rate) - self.per = float(per) - self._window = 0.0 - self._tokens = self.rate - self._last = 0.0 + __slots__ = ('rate', 'per', '_window', '_tokens', '_last') - def get_tokens(self, current=None): + def __init__(self, rate: float, per: float) -> None: + self.rate: int = int(rate) + self.per: float = float(per) + self._window: float = 0.0 + self._tokens: int = self.rate + self._last: float = 0.0 + + def get_tokens(self, current: Optional[float] = None) -> int: + """Returns the number of available tokens before rate limiting is applied. + + Parameters + ------------ + current: Optional[:class:`float`] + The time in seconds since Unix epoch to calculate tokens at. + If not supplied then :func:`time.time()` is used. + + Returns + -------- + :class:`int` + The number of tokens available before the cooldown is to be applied. + """ if not current: current = time.time() @@ -89,7 +122,20 @@ class Cooldown: tokens = self.rate return tokens - def get_retry_after(self, current=None): + def get_retry_after(self, current: Optional[float] = None) -> float: + """Returns the time in seconds until the cooldown will be reset. + + Parameters + ------------- + current: Optional[:class:`float`] + The current time in seconds since Unix epoch. + If not supplied, then :func:`time.time()` is used. + + Returns + ------- + :class:`float` + The number of seconds to wait before this cooldown will be reset. + """ current = current or time.time() tokens = self.get_tokens(current) @@ -98,7 +144,20 @@ class Cooldown: return 0.0 - def update_rate_limit(self, current=None): + def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: + """Updates the cooldown rate limit. + + Parameters + ------------- + current: Optional[:class:`float`] + The time in seconds since Unix epoch to update the rate limit at. + If not supplied, then :func:`time.time()` is used. + + Returns + ------- + Optional[:class:`float`] + The retry-after time in seconds if rate limited. + """ current = current or time.time() self._last = current @@ -115,46 +174,58 @@ class Cooldown: # we're not so decrement our tokens self._tokens -= 1 - def reset(self): + def reset(self) -> None: + """Reset the cooldown to its initial state.""" self._tokens = self.rate self._last = 0.0 - def copy(self): + def copy(self) -> Cooldown: + """Creates a copy of this cooldown. + + Returns + -------- + :class:`Cooldown` + A new instance of this cooldown. + """ return Cooldown(self.rate, self.per) - def __repr__(self): + def __repr__(self) -> str: return f'' class CooldownMapping: - def __init__(self, original, type): + def __init__( + self, + original: Optional[Cooldown], + type: Callable[[Message], Any], + ) -> None: if not callable(type): raise TypeError('Cooldown type must be a BucketType or callable') - self._cache = {} - self._cooldown = original - self._type = type + self._cache: Dict[Any, Cooldown] = {} + self._cooldown: Optional[Cooldown] = original + self._type: Callable[[Message], Any] = type - def copy(self): + def copy(self) -> CooldownMapping: ret = CooldownMapping(self._cooldown, self._type) ret._cache = self._cache.copy() return ret @property - def valid(self): + def valid(self) -> bool: return self._cooldown is not None @property - def type(self): + def type(self) -> Callable[[Message], Any]: return self._type @classmethod - def from_cooldown(cls, rate, per, type): + def from_cooldown(cls: Type[C], rate, per, type) -> C: return cls(Cooldown(rate, per), type) - def _bucket_key(self, msg): + def _bucket_key(self, msg: Message) -> Any: return self._type(msg) - def _verify_cache_integrity(self, current=None): + def _verify_cache_integrity(self, current: Optional[float] = None) -> None: # we want to delete all cache objects that haven't been used # in a cooldown window. e.g. if we have a command that has a # cooldown of 60s and it has not been used in 60s then that key should be deleted @@ -163,12 +234,12 @@ class CooldownMapping: for k in dead_keys: del self._cache[k] - def create_bucket(self, message): - return self._cooldown.copy() + def create_bucket(self, message: Message) -> Cooldown: + return self._cooldown.copy() # type: ignore - def get_bucket(self, message, current=None): + def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown: if self._type is BucketType.default: - return self._cooldown + return self._cooldown # type: ignore self._verify_cache_integrity(current) key = self._bucket_key(message) @@ -181,26 +252,30 @@ class CooldownMapping: return bucket - def update_rate_limit(self, message, current=None): + def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) class DynamicCooldownMapping(CooldownMapping): - def __init__(self, factory, type): + def __init__( + self, + factory: Callable[[Message], Cooldown], + type: Callable[[Message], Any] + ) -> None: super().__init__(None, type) - self._factory = factory + self._factory: Callable[[Message], Cooldown] = factory - def copy(self): + def copy(self) -> DynamicCooldownMapping: ret = DynamicCooldownMapping(self._factory, self._type) ret._cache = self._cache.copy() return ret @property - def valid(self): + def valid(self) -> bool: return True - def create_bucket(self, message): + def create_bucket(self, message: Message) -> Cooldown: return self._factory(message) class _Semaphore: @@ -218,28 +293,28 @@ class _Semaphore: __slots__ = ('value', 'loop', '_waiters') - def __init__(self, number): - self.value = number - self.loop = asyncio.get_event_loop() - self._waiters = deque() + def __init__(self, number: int) -> None: + self.value: int = number + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._waiters: Deque[asyncio.Future] = deque() - def __repr__(self): + def __repr__(self) -> str: return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' - def locked(self): + def locked(self) -> bool: return self.value == 0 - def is_active(self): + def is_active(self) -> bool: return len(self._waiters) > 0 - def wake_up(self): + def wake_up(self) -> None: while self._waiters: future = self._waiters.popleft() if not future.done(): future.set_result(None) return - async def acquire(self, *, wait=False): + async def acquire(self, *, wait: bool = False) -> bool: if not wait and self.value <= 0: # signal that we're not acquiring return False @@ -258,18 +333,18 @@ class _Semaphore: self.value -= 1 return True - def release(self): + def release(self) -> None: self.value += 1 self.wake_up() class MaxConcurrency: __slots__ = ('number', 'per', 'wait', '_mapping') - def __init__(self, number, *, per, wait): - self._mapping = {} - self.per = per - self.number = number - self.wait = wait + def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: + self._mapping: Dict[Any, _Semaphore] = {} + self.per: BucketType = per + self.number: int = number + self.wait: bool = wait if number <= 0: raise ValueError('max_concurrency \'number\' cannot be less than 1') @@ -277,16 +352,16 @@ class MaxConcurrency: if not isinstance(per, BucketType): raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}') - def copy(self): + def copy(self: MC) -> MC: return self.__class__(self.number, per=self.per, wait=self.wait) - def __repr__(self): + def __repr__(self) -> str: return f'' - def get_key(self, message): + def get_key(self, message: Message) -> Any: return self.per.get_key(message) - async def acquire(self, message): + async def acquire(self, message: Message) -> None: key = self.get_key(message) try: @@ -298,7 +373,7 @@ class MaxConcurrency: if not acquired: raise MaxConcurrencyReached(self.number, self.per) - async def release(self, message): + async def release(self, message: Message) -> None: # Technically there's no reason for this function to be async # But it might be more useful in the future key = self.get_key(message) diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 5a581d51e..9c9c7bab5 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -493,7 +493,7 @@ class CommandOnCooldown(CommandError): Attributes ----------- - cooldown: ``Cooldown`` + cooldown: :class:`.Cooldown` A class with attributes ``rate`` and ``per`` similar to the :func:`.cooldown` decorator. type: :class:`BucketType` diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index 93b01343b..e96315145 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -330,6 +330,14 @@ Checks .. _ext_commands_api_context: +Cooldown +--------- + +.. attributetable:: discord.ext.commands.Cooldown + +.. autoclass:: discord.ext.commands.Cooldown + :members: + Context --------