Browse Source

Allow configuring the maximum ratelimit timeout before erroring

This is useful for cases where a rate limit is known to be
extraordinarily high, but you still want to handle the error.
This is common with routes such as emoji creation.
pull/8258/head
Rapptz 3 years ago
parent
commit
76402b00f9
  1. 10
      discord/client.py
  2. 25
      discord/errors.py
  3. 51
      discord/http.py
  4. 4
      docs/api.rst

10
discord/client.py

@ -260,6 +260,14 @@ class Client:
This allows you to check requests the library is using. For more information, check the This allows you to check requests the library is using. For more information, check the
`aiohttp documentation <https://docs.aiohttp.org/en/stable/client_advanced.html#client-tracing>`_. `aiohttp documentation <https://docs.aiohttp.org/en/stable/client_advanced.html#client-tracing>`_.
.. versionadded:: 2.0
max_ratelimit_timeout: Optional[:class:`float`]
The maximum number of seconds to wait when a non-global rate limit is encountered.
If a request requires sleeping for more than the seconds passed in, then
:exc:`~discord.RateLimited` will be raised. By default, there is no timeout limit.
In order to prevent misuse and unnecessary bans, the minimum value this can be
set to is ``30.0`` seconds.
.. versionadded:: 2.0 .. versionadded:: 2.0
Attributes Attributes
@ -280,12 +288,14 @@ class Client:
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True) unsync_clock: bool = options.pop('assume_unsync_clock', True)
http_trace: Optional[aiohttp.TraceConfig] = options.pop('http_trace', None) http_trace: Optional[aiohttp.TraceConfig] = options.pop('http_trace', None)
max_ratelimit_timeout: Optional[float] = options.pop('max_ratelimit_timeout', None)
self.http: HTTPClient = HTTPClient( self.http: HTTPClient = HTTPClient(
self.loop, self.loop,
proxy=proxy, proxy=proxy,
proxy_auth=proxy_auth, proxy_auth=proxy_auth,
unsync_clock=unsync_clock, unsync_clock=unsync_clock,
http_trace=http_trace, http_trace=http_trace,
max_ratelimit_timeout=max_ratelimit_timeout,
) )
self._handlers: Dict[str, Callable[..., None]] = { self._handlers: Dict[str, Callable[..., None]] = {

25
discord/errors.py

@ -38,6 +38,7 @@ __all__ = (
'ClientException', 'ClientException',
'GatewayNotFound', 'GatewayNotFound',
'HTTPException', 'HTTPException',
'RateLimited',
'Forbidden', 'Forbidden',
'NotFound', 'NotFound',
'DiscordServerError', 'DiscordServerError',
@ -137,6 +138,30 @@ class HTTPException(DiscordException):
super().__init__(fmt.format(self.response, self.code, self.text)) super().__init__(fmt.format(self.response, self.code, self.text))
class RateLimited(DiscordException):
"""Exception that's raised for when status code 429 occurs
and the timeout is greater than the configured maximum using
the ``max_ratelimit_timeout`` parameter in :class:`Client`.
This is not raised during global ratelimits.
Since sometimes requests are halted pre-emptively before they're
even made, **this does not subclass :exc:`HTTPException`.**
.. versionadded:: 2.0
Attributes
------------
retry_after: :class:`float`
The amount of seconds that the client should wait before retrying
the request.
"""
def __init__(self, retry_after: float):
self.retry_after = retry_after
super().__init__(f'Too many requests. Retry in {retry_after:.2f} seconds.')
class Forbidden(HTTPException): class Forbidden(HTTPException):
"""Exception that's raised for when status code 403 occurs. """Exception that's raised for when status code 403 occurs.

51
discord/http.py

@ -52,7 +52,7 @@ import datetime
import aiohttp import aiohttp
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound from .errors import HTTPException, RateLimited, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound
from .gateway import DiscordClientWebSocketResponse from .gateway import DiscordClientWebSocketResponse
from .file import File from .file import File
from .mentions import AllowedMentions from .mentions import AllowedMentions
@ -328,13 +328,15 @@ class Ratelimit:
design is to increase throughput of requests being sent concurrently rather than forcing design is to increase throughput of requests being sent concurrently rather than forcing
everything into a single lock queue per route. everything into a single lock queue per route.
""" """
def __init__(self) -> None:
def __init__(self, max_ratelimit_timeout: Optional[float]) -> None:
self.limit: int = 1 self.limit: int = 1
self.remaining: int = self.limit self.remaining: int = self.limit
self.outgoing: int = 0 self.outgoing: int = 0
self.reset_after: float = 0.0 self.reset_after: float = 0.0
self.expires: Optional[float] = None self.expires: Optional[float] = None
self.dirty: bool = False self.dirty: bool = False
self._max_ratelimit_timeout: Optional[float] = max_ratelimit_timeout
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._pending_requests: deque[asyncio.Future[Any]] = deque() self._pending_requests: deque[asyncio.Future[Any]] = deque()
# Only a single rate limit object should be sleeping at a time. # Only a single rate limit object should be sleeping at a time.
@ -381,12 +383,15 @@ class Ratelimit:
future.set_result(None) future.set_result(None)
break break
def _wake(self, count: int = 1) -> None: def _wake(self, count: int = 1, *, exception: Optional[RateLimited] = None) -> None:
awaken = 0 awaken = 0
while self._pending_requests: while self._pending_requests:
future = self._pending_requests.popleft() future = self._pending_requests.popleft()
if not future.done(): if not future.done():
future.set_result(None) if exception:
future.set_exception(exception)
else:
future.set_result(None)
self._has_just_awaken = True self._has_just_awaken = True
awaken += 1 awaken += 1
@ -394,10 +399,14 @@ class Ratelimit:
break break
async def _refresh(self) -> None: async def _refresh(self) -> None:
error = self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout
exception = RateLimited(self.reset_after) if error else None
async with self._sleeping: async with self._sleeping:
await asyncio.sleep(self.reset_after) if not error:
await asyncio.sleep(self.reset_after)
self.reset() self.reset()
self._wake(self.remaining) self._wake(self.remaining, exception=exception)
def is_expired(self) -> bool: def is_expired(self) -> bool:
return self.expires is not None and self._loop.time() > self.expires return self.expires is not None and self._loop.time() > self.expires
@ -406,6 +415,12 @@ class Ratelimit:
if self.is_expired(): if self.is_expired():
self.reset() self.reset()
if self._max_ratelimit_timeout is not None and self.expires is not None:
# Check if we can pre-emptively block this request for having too large of a timeout
current_reset_after = self.expires - self._loop.time()
if current_reset_after > self._max_ratelimit_timeout:
raise RateLimited(current_reset_after)
while self.remaining <= 0: while self.remaining <= 0:
future = self._loop.create_future() future = self._loop.create_future()
self._pending_requests.append(future) self._pending_requests.append(future)
@ -433,7 +448,12 @@ class Ratelimit:
if tokens <= 0: if tokens <= 0:
await self._refresh() await self._refresh()
elif self._pending_requests: elif self._pending_requests:
self._wake(tokens) exception = (
RateLimited(self.reset_after)
if self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout
else None
)
self._wake(tokens, exception=exception)
# For some reason, the Discord voice websocket expects this header to be # For some reason, the Discord voice websocket expects this header to be
@ -453,6 +473,7 @@ class HTTPClient:
proxy_auth: Optional[aiohttp.BasicAuth] = None, proxy_auth: Optional[aiohttp.BasicAuth] = None,
unsync_clock: bool = True, unsync_clock: bool = True,
http_trace: Optional[aiohttp.TraceConfig] = None, http_trace: Optional[aiohttp.TraceConfig] = None,
max_ratelimit_timeout: Optional[float] = None,
) -> None: ) -> None:
self.loop: asyncio.AbstractEventLoop = loop self.loop: asyncio.AbstractEventLoop = loop
self.connector: aiohttp.BaseConnector = connector or MISSING self.connector: aiohttp.BaseConnector = connector or MISSING
@ -472,6 +493,7 @@ class HTTPClient:
self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth
self.http_trace: Optional[aiohttp.TraceConfig] = http_trace self.http_trace: Optional[aiohttp.TraceConfig] = http_trace
self.use_clock: bool = not unsync_clock self.use_clock: bool = not unsync_clock
self.max_ratelimit_timeout: Optional[float] = max(30.0, max_ratelimit_timeout) if max_ratelimit_timeout else None
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__) self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
@ -520,7 +542,7 @@ class HTTPClient:
try: try:
ratelimit = mapping[key] ratelimit = mapping[key]
except KeyError: except KeyError:
mapping[key] = ratelimit = Ratelimit() mapping[key] = ratelimit = Ratelimit(self.max_ratelimit_timeout)
# header creation # header creation
headers: Dict[str, str] = { headers: Dict[str, str] = {
@ -628,10 +650,17 @@ class HTTPClient:
# Banned by Cloudflare more than likely. # Banned by Cloudflare more than likely.
raise HTTPException(response, data) raise HTTPException(response, data)
fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.'
# sleep a bit
retry_after: float = data['retry_after'] retry_after: float = data['retry_after']
if self.max_ratelimit_timeout and retry_after > self.max_ratelimit_timeout:
_log.warning(
'We are being rate limited. %s %s responded with 429. Timeout of %.2f was too long, erroring instead.',
method,
url,
retry_after,
)
raise RateLimited(retry_after)
fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.'
_log.warning(fmt, method, url, retry_after, stack_info=True) _log.warning(fmt, method, url, retry_after, stack_info=True)
_log.debug( _log.debug(

4
docs/api.rst

@ -4692,6 +4692,9 @@ The following exceptions are thrown by the library.
.. autoexception:: HTTPException .. autoexception:: HTTPException
:members: :members:
.. autoexception:: RateLimited
:members:
.. autoexception:: Forbidden .. autoexception:: Forbidden
.. autoexception:: NotFound .. autoexception:: NotFound
@ -4730,3 +4733,4 @@ Exception Hierarchy
- :exc:`Forbidden` - :exc:`Forbidden`
- :exc:`NotFound` - :exc:`NotFound`
- :exc:`DiscordServerError` - :exc:`DiscordServerError`
- :exc:`RateLimited`

Loading…
Cancel
Save