Browse Source

[commands] Document / type-hint cooldown

pull/7362/head
Josh 4 years ago
committed by GitHub
parent
commit
1c63816cc0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      discord/client.py
  2. 189
      discord/ext/commands/cooldowns.py
  3. 2
      discord/ext/commands/errors.py
  4. 8
      docs/ext/commands/api.rst

2
discord/client.py

@ -289,7 +289,7 @@ class Client:
@property @property
def stickers(self) -> List[GuildSticker]: def stickers(self) -> List[GuildSticker]:
"""List[:class:`GuildSticker`]: The stickers that the connected client has. """List[:class:`.GuildSticker`]: The stickers that the connected client has.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """

189
discord/ext/commands/cooldowns.py

@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING
from discord.enums import Enum from discord.enums import Enum
import time import time
import asyncio import asyncio
@ -30,6 +34,9 @@ from collections import deque
from ...abc import PrivateChannel from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached from .errors import MaxConcurrencyReached
if TYPE_CHECKING:
from ...message import Message
__all__ = ( __all__ = (
'BucketType', 'BucketType',
'Cooldown', 'Cooldown',
@ -38,6 +45,9 @@ __all__ = (
'MaxConcurrency', 'MaxConcurrency',
) )
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
class BucketType(Enum): class BucketType(Enum):
default = 0 default = 0
user = 1 user = 1
@ -47,7 +57,7 @@ class BucketType(Enum):
category = 5 category = 5
role = 6 role = 6
def get_key(self, msg): def get_key(self, msg: Message) -> Any:
if self is BucketType.user: if self is BucketType.user:
return msg.author.id return msg.author.id
elif self is BucketType.guild: elif self is BucketType.guild:
@ -57,29 +67,52 @@ class BucketType(Enum):
elif self is BucketType.member: elif self is BucketType.member:
return ((msg.guild and msg.guild.id), msg.author.id) return ((msg.guild and msg.guild.id), msg.author.id)
elif self is BucketType.category: elif self is BucketType.category:
return (msg.channel.category or msg.channel).id return (msg.channel.category or msg.channel).id # type: ignore
elif self is BucketType.role: elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds # we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role # and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
def __call__(self, msg): def __call__(self, msg: Message) -> Any:
return self.get_key(msg) return self.get_key(msg)
class Cooldown: class Cooldown:
__slots__ = ('rate', 'per', '_window', '_tokens', '_last') """Represents a cooldown for a command.
Attributes
-----------
rate: :class:`int`
The total number of tokens available per :attr:`per` seconds.
per: :class:`float`
The length of the cooldown period in seconds.
"""
def __init__(self, rate, per): __slots__ = ('rate', 'per', '_window', '_tokens', '_last')
self.rate = int(rate)
self.per = float(per)
self._window = 0.0
self._tokens = self.rate
self._last = 0.0
def get_tokens(self, current=None): def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
self.per: float = float(per)
self._window: float = 0.0
self._tokens: int = self.rate
self._last: float = 0.0
def get_tokens(self, current: Optional[float] = None) -> int:
"""Returns the number of available tokens before rate limiting is applied.
Parameters
------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to calculate tokens at.
If not supplied then :func:`time.time()` is used.
Returns
--------
:class:`int`
The number of tokens available before the cooldown is to be applied.
"""
if not current: if not current:
current = time.time() current = time.time()
@ -89,7 +122,20 @@ class Cooldown:
tokens = self.rate tokens = self.rate
return tokens return tokens
def get_retry_after(self, current=None): def get_retry_after(self, current: Optional[float] = None) -> float:
"""Returns the time in seconds until the cooldown will be reset.
Parameters
-------------
current: Optional[:class:`float`]
The current time in seconds since Unix epoch.
If not supplied, then :func:`time.time()` is used.
Returns
-------
:class:`float`
The number of seconds to wait before this cooldown will be reset.
"""
current = current or time.time() current = current or time.time()
tokens = self.get_tokens(current) tokens = self.get_tokens(current)
@ -98,7 +144,20 @@ class Cooldown:
return 0.0 return 0.0
def update_rate_limit(self, current=None): def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]:
"""Updates the cooldown rate limit.
Parameters
-------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to update the rate limit at.
If not supplied, then :func:`time.time()` is used.
Returns
-------
Optional[:class:`float`]
The retry-after time in seconds if rate limited.
"""
current = current or time.time() current = current or time.time()
self._last = current self._last = current
@ -115,46 +174,58 @@ class Cooldown:
# we're not so decrement our tokens # we're not so decrement our tokens
self._tokens -= 1 self._tokens -= 1
def reset(self): def reset(self) -> None:
"""Reset the cooldown to its initial state."""
self._tokens = self.rate self._tokens = self.rate
self._last = 0.0 self._last = 0.0
def copy(self): def copy(self) -> Cooldown:
"""Creates a copy of this cooldown.
Returns
--------
:class:`Cooldown`
A new instance of this cooldown.
"""
return Cooldown(self.rate, self.per) return Cooldown(self.rate, self.per)
def __repr__(self): def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>' return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping: class CooldownMapping:
def __init__(self, original, type): def __init__(
self,
original: Optional[Cooldown],
type: Callable[[Message], Any],
) -> None:
if not callable(type): if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable') raise TypeError('Cooldown type must be a BucketType or callable')
self._cache = {} self._cache: Dict[Any, Cooldown] = {}
self._cooldown = original self._cooldown: Optional[Cooldown] = original
self._type = type self._type: Callable[[Message], Any] = type
def copy(self): def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type) ret = CooldownMapping(self._cooldown, self._type)
ret._cache = self._cache.copy() ret._cache = self._cache.copy()
return ret return ret
@property @property
def valid(self): def valid(self) -> bool:
return self._cooldown is not None return self._cooldown is not None
@property @property
def type(self): def type(self) -> Callable[[Message], Any]:
return self._type return self._type
@classmethod @classmethod
def from_cooldown(cls, rate, per, type): def from_cooldown(cls: Type[C], rate, per, type) -> C:
return cls(Cooldown(rate, per), type) return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg): def _bucket_key(self, msg: Message) -> Any:
return self._type(msg) return self._type(msg)
def _verify_cache_integrity(self, current=None): def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
# we want to delete all cache objects that haven't been used # we want to delete all cache objects that haven't been used
# in a cooldown window. e.g. if we have a command that has a # in a cooldown window. e.g. if we have a command that has a
# cooldown of 60s and it has not been used in 60s then that key should be deleted # cooldown of 60s and it has not been used in 60s then that key should be deleted
@ -163,12 +234,12 @@ class CooldownMapping:
for k in dead_keys: for k in dead_keys:
del self._cache[k] del self._cache[k]
def create_bucket(self, message): def create_bucket(self, message: Message) -> Cooldown:
return self._cooldown.copy() return self._cooldown.copy() # type: ignore
def get_bucket(self, message, current=None): def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
if self._type is BucketType.default: if self._type is BucketType.default:
return self._cooldown return self._cooldown # type: ignore
self._verify_cache_integrity(current) self._verify_cache_integrity(current)
key = self._bucket_key(message) key = self._bucket_key(message)
@ -181,26 +252,30 @@ class CooldownMapping:
return bucket return bucket
def update_rate_limit(self, message, current=None): def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
bucket = self.get_bucket(message, current) bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current) return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping): class DynamicCooldownMapping(CooldownMapping):
def __init__(self, factory, type): def __init__(
self,
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
) -> None:
super().__init__(None, type) super().__init__(None, type)
self._factory = factory self._factory: Callable[[Message], Cooldown] = factory
def copy(self): def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type) ret = DynamicCooldownMapping(self._factory, self._type)
ret._cache = self._cache.copy() ret._cache = self._cache.copy()
return ret return ret
@property @property
def valid(self): def valid(self) -> bool:
return True return True
def create_bucket(self, message): def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message) return self._factory(message)
class _Semaphore: class _Semaphore:
@ -218,28 +293,28 @@ class _Semaphore:
__slots__ = ('value', 'loop', '_waiters') __slots__ = ('value', 'loop', '_waiters')
def __init__(self, number): def __init__(self, number: int) -> None:
self.value = number self.value: int = number
self.loop = asyncio.get_event_loop() self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self._waiters = deque() self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self): def __repr__(self) -> str:
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
def locked(self): def locked(self) -> bool:
return self.value == 0 return self.value == 0
def is_active(self): def is_active(self) -> bool:
return len(self._waiters) > 0 return len(self._waiters) > 0
def wake_up(self): def wake_up(self) -> None:
while self._waiters: while self._waiters:
future = self._waiters.popleft() future = self._waiters.popleft()
if not future.done(): if not future.done():
future.set_result(None) future.set_result(None)
return return
async def acquire(self, *, wait=False): async def acquire(self, *, wait: bool = False) -> bool:
if not wait and self.value <= 0: if not wait and self.value <= 0:
# signal that we're not acquiring # signal that we're not acquiring
return False return False
@ -258,18 +333,18 @@ class _Semaphore:
self.value -= 1 self.value -= 1
return True return True
def release(self): def release(self) -> None:
self.value += 1 self.value += 1
self.wake_up() self.wake_up()
class MaxConcurrency: class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping') __slots__ = ('number', 'per', 'wait', '_mapping')
def __init__(self, number, *, per, wait): def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping = {} self._mapping: Dict[Any, _Semaphore] = {}
self.per = per self.per: BucketType = per
self.number = number self.number: int = number
self.wait = wait self.wait: bool = wait
if number <= 0: if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1') raise ValueError('max_concurrency \'number\' cannot be less than 1')
@ -277,16 +352,16 @@ class MaxConcurrency:
if not isinstance(per, BucketType): if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}') raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
def copy(self): def copy(self: MC) -> MC:
return self.__class__(self.number, per=self.per, wait=self.wait) return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self): def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>' return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
def get_key(self, message): def get_key(self, message: Message) -> Any:
return self.per.get_key(message) return self.per.get_key(message)
async def acquire(self, message): async def acquire(self, message: Message) -> None:
key = self.get_key(message) key = self.get_key(message)
try: try:
@ -298,7 +373,7 @@ class MaxConcurrency:
if not acquired: if not acquired:
raise MaxConcurrencyReached(self.number, self.per) raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message): async def release(self, message: Message) -> None:
# Technically there's no reason for this function to be async # Technically there's no reason for this function to be async
# But it might be more useful in the future # But it might be more useful in the future
key = self.get_key(message) key = self.get_key(message)

2
discord/ext/commands/errors.py

@ -493,7 +493,7 @@ class CommandOnCooldown(CommandError):
Attributes Attributes
----------- -----------
cooldown: ``Cooldown`` cooldown: :class:`.Cooldown`
A class with attributes ``rate`` and ``per`` similar to the A class with attributes ``rate`` and ``per`` similar to the
:func:`.cooldown` decorator. :func:`.cooldown` decorator.
type: :class:`BucketType` type: :class:`BucketType`

8
docs/ext/commands/api.rst

@ -330,6 +330,14 @@ Checks
.. _ext_commands_api_context: .. _ext_commands_api_context:
Cooldown
---------
.. attributetable:: discord.ext.commands.Cooldown
.. autoclass:: discord.ext.commands.Cooldown
:members:
Context Context
-------- --------

Loading…
Cancel
Save