diff --git a/discord/client.py b/discord/client.py index b575caf48..7518dca0e 100644 --- a/discord/client.py +++ b/discord/client.py @@ -260,6 +260,14 @@ class Client: This allows you to check requests the library is using. For more information, check the `aiohttp documentation `_. + .. versionadded:: 2.0 + max_ratelimit_timeout: Optional[:class:`float`] + The maximum number of seconds to wait when a non-global rate limit is encountered. + If a request requires sleeping for more than the seconds passed in, then + :exc:`~discord.RateLimited` will be raised. By default, there is no timeout limit. + In order to prevent misuse and unnecessary bans, the minimum value this can be + set to is ``30.0`` seconds. + .. versionadded:: 2.0 Attributes @@ -280,12 +288,14 @@ class Client: proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) unsync_clock: bool = options.pop('assume_unsync_clock', True) http_trace: Optional[aiohttp.TraceConfig] = options.pop('http_trace', None) + max_ratelimit_timeout: Optional[float] = options.pop('max_ratelimit_timeout', None) self.http: HTTPClient = HTTPClient( self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, http_trace=http_trace, + max_ratelimit_timeout=max_ratelimit_timeout, ) self._handlers: Dict[str, Callable[..., None]] = { diff --git a/discord/errors.py b/discord/errors.py index b3b207fe1..925d303ef 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -38,6 +38,7 @@ __all__ = ( 'ClientException', 'GatewayNotFound', 'HTTPException', + 'RateLimited', 'Forbidden', 'NotFound', 'DiscordServerError', @@ -137,6 +138,30 @@ class HTTPException(DiscordException): super().__init__(fmt.format(self.response, self.code, self.text)) +class RateLimited(DiscordException): + """Exception that's raised for when status code 429 occurs + and the timeout is greater than the configured maximum using + the ``max_ratelimit_timeout`` parameter in :class:`Client`. + + This is not raised during global ratelimits. + + Since sometimes requests are halted pre-emptively before they're + even made, **this does not subclass :exc:`HTTPException`.** + + .. versionadded:: 2.0 + + Attributes + ------------ + retry_after: :class:`float` + The amount of seconds that the client should wait before retrying + the request. + """ + + def __init__(self, retry_after: float): + self.retry_after = retry_after + super().__init__(f'Too many requests. Retry in {retry_after:.2f} seconds.') + + class Forbidden(HTTPException): """Exception that's raised for when status code 403 occurs. diff --git a/discord/http.py b/discord/http.py index dd52c9f3c..87d533d58 100644 --- a/discord/http.py +++ b/discord/http.py @@ -52,7 +52,7 @@ import datetime import aiohttp -from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound +from .errors import HTTPException, RateLimited, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound from .gateway import DiscordClientWebSocketResponse from .file import File from .mentions import AllowedMentions @@ -328,13 +328,15 @@ class Ratelimit: design is to increase throughput of requests being sent concurrently rather than forcing everything into a single lock queue per route. """ - def __init__(self) -> None: + + def __init__(self, max_ratelimit_timeout: Optional[float]) -> None: self.limit: int = 1 self.remaining: int = self.limit self.outgoing: int = 0 self.reset_after: float = 0.0 self.expires: Optional[float] = None self.dirty: bool = False + self._max_ratelimit_timeout: Optional[float] = max_ratelimit_timeout self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() self._pending_requests: deque[asyncio.Future[Any]] = deque() # Only a single rate limit object should be sleeping at a time. @@ -381,12 +383,15 @@ class Ratelimit: future.set_result(None) break - def _wake(self, count: int = 1) -> None: + def _wake(self, count: int = 1, *, exception: Optional[RateLimited] = None) -> None: awaken = 0 while self._pending_requests: future = self._pending_requests.popleft() if not future.done(): - future.set_result(None) + if exception: + future.set_exception(exception) + else: + future.set_result(None) self._has_just_awaken = True awaken += 1 @@ -394,10 +399,14 @@ class Ratelimit: break async def _refresh(self) -> None: + error = self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout + exception = RateLimited(self.reset_after) if error else None async with self._sleeping: - await asyncio.sleep(self.reset_after) + if not error: + await asyncio.sleep(self.reset_after) + self.reset() - self._wake(self.remaining) + self._wake(self.remaining, exception=exception) def is_expired(self) -> bool: return self.expires is not None and self._loop.time() > self.expires @@ -406,6 +415,12 @@ class Ratelimit: if self.is_expired(): self.reset() + if self._max_ratelimit_timeout is not None and self.expires is not None: + # Check if we can pre-emptively block this request for having too large of a timeout + current_reset_after = self.expires - self._loop.time() + if current_reset_after > self._max_ratelimit_timeout: + raise RateLimited(current_reset_after) + while self.remaining <= 0: future = self._loop.create_future() self._pending_requests.append(future) @@ -433,7 +448,12 @@ class Ratelimit: if tokens <= 0: await self._refresh() elif self._pending_requests: - self._wake(tokens) + exception = ( + RateLimited(self.reset_after) + if self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout + else None + ) + self._wake(tokens, exception=exception) # For some reason, the Discord voice websocket expects this header to be @@ -453,6 +473,7 @@ class HTTPClient: proxy_auth: Optional[aiohttp.BasicAuth] = None, unsync_clock: bool = True, http_trace: Optional[aiohttp.TraceConfig] = None, + max_ratelimit_timeout: Optional[float] = None, ) -> None: self.loop: asyncio.AbstractEventLoop = loop self.connector: aiohttp.BaseConnector = connector or MISSING @@ -472,6 +493,7 @@ class HTTPClient: self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth self.http_trace: Optional[aiohttp.TraceConfig] = http_trace self.use_clock: bool = not unsync_clock + self.max_ratelimit_timeout: Optional[float] = max(30.0, max_ratelimit_timeout) if max_ratelimit_timeout else None user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__) @@ -520,7 +542,7 @@ class HTTPClient: try: ratelimit = mapping[key] except KeyError: - mapping[key] = ratelimit = Ratelimit() + mapping[key] = ratelimit = Ratelimit(self.max_ratelimit_timeout) # header creation headers: Dict[str, str] = { @@ -628,10 +650,17 @@ class HTTPClient: # Banned by Cloudflare more than likely. raise HTTPException(response, data) - fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.' - - # sleep a bit retry_after: float = data['retry_after'] + if self.max_ratelimit_timeout and retry_after > self.max_ratelimit_timeout: + _log.warning( + 'We are being rate limited. %s %s responded with 429. Timeout of %.2f was too long, erroring instead.', + method, + url, + retry_after, + ) + raise RateLimited(retry_after) + + fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.' _log.warning(fmt, method, url, retry_after, stack_info=True) _log.debug( diff --git a/docs/api.rst b/docs/api.rst index e4f7c2fde..0f039e858 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4692,6 +4692,9 @@ The following exceptions are thrown by the library. .. autoexception:: HTTPException :members: +.. autoexception:: RateLimited + :members: + .. autoexception:: Forbidden .. autoexception:: NotFound @@ -4730,3 +4733,4 @@ Exception Hierarchy - :exc:`Forbidden` - :exc:`NotFound` - :exc:`DiscordServerError` + - :exc:`RateLimited`