diff --git a/discord/app_commands/checks.py b/discord/app_commands/checks.py index 7b7d45041..fcc97260a 100644 --- a/discord/app_commands/checks.py +++ b/discord/app_commands/checks.py @@ -112,7 +112,8 @@ class Cooldown: if not current: current = time.time() - tokens = self._tokens + # the calculated tokens should be non-negative + tokens = max(self._tokens, 0) if current > self._window + self.per: tokens = self.rate @@ -140,7 +141,7 @@ class Cooldown: return 0.0 - def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: + def update_rate_limit(self, current: Optional[float] = None, *, tokens: int = 1) -> Optional[float]: """Updates the cooldown rate limit. Parameters @@ -148,6 +149,8 @@ class Cooldown: current: Optional[:class:`float`] The time in seconds since Unix epoch to update the rate limit at. If not supplied, then :func:`time.time()` is used. + tokens: :class:`int` + The amount of tokens to deduct from the rate limit. Returns ------- @@ -163,12 +166,12 @@ class Cooldown: if self._tokens == self.rate: self._window = current - # check if we are rate limited - if self._tokens == 0: - return self.per - (current - self._window) + # decrement tokens by specified number + self._tokens -= tokens - # we're not so decrement our tokens - self._tokens -= 1 + # check if we are rate limited and return retry-after + if self._tokens < 0: + return self.per - (current - self._window) def reset(self) -> None: """Reset the cooldown to its initial state.""" diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 534658728..fc57e22e0 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -140,9 +140,9 @@ class CooldownMapping: return bucket - def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: + def update_rate_limit(self, message: Message, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: bucket = self.get_bucket(message, current) - return bucket.update_rate_limit(current) + return bucket.update_rate_limit(current, tokens=tokens) class DynamicCooldownMapping(CooldownMapping):