diff --git a/discord/http.py b/discord/http.py index 7b82fddb6..427334461 100644 --- a/discord/http.py +++ b/discord/http.py @@ -26,6 +26,7 @@ from __future__ import annotations import asyncio import logging +import time import sys from typing import ( Any, @@ -46,7 +47,6 @@ from typing import ( Union, ) from urllib.parse import quote as _uriquote -from collections import deque import datetime import aiohttp @@ -343,165 +343,181 @@ class Route: str(k) for k in (self.channel_id, self.guild_id, self.webhook_id, self.webhook_token) if k is not None ) - class Ratelimit: """Represents a Discord rate limit. This is similar to a semaphore except tailored to Discord's rate limits. This is aware of the expiry of a token window, along with the number of tokens available. The goal of this - design is to increase throughput of requests being sent concurrently rather than forcing - everything into a single lock queue per route. + design is to increase throughput of requests being sent concurrently by forcing everything + into a single lock queue per route. """ __slots__ = ( + 'key', 'limit', 'remaining', + 'reset_at', + 'pending', 'outgoing', - 'reset_after', - 'expires', - 'dirty', + 'one_shot', + 'http', '_last_request', - '_max_ratelimit_timeout', - '_loop', - '_pending_requests', - '_sleeping', + '_lock', + '_event', ) - def __init__(self, max_ratelimit_timeout: Optional[float]) -> None: + def __init__(self, http: HTTPClient, key: str) -> None: + self.key = key self.limit: int = 1 self.remaining: int = self.limit + self.reset_at: float = 0.0 + self.pending: int = 0 self.outgoing: int = 0 - self.reset_after: float = 0.0 - self.expires: Optional[float] = None - self.dirty: bool = False - self._max_ratelimit_timeout: Optional[float] = max_ratelimit_timeout - self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() - self._pending_requests: deque[asyncio.Future[Any]] = deque() - # Only a single rate limit object should be sleeping at a time. 
- # The object that is sleeping is ultimately responsible for freeing the semaphore - # for the requests currently pending. - self._sleeping: asyncio.Lock = asyncio.Lock() - self._last_request: float = self._loop.time() + self.one_shot: bool = False + self.http: HTTPClient = http + self._last_request: float = 0.0 + self._lock: asyncio.Lock = asyncio.Lock() + self._event: asyncio.Event = asyncio.Event() def __repr__(self) -> str: return ( - f'' + f'' ) - def reset(self): + def no_headers(self) -> None: + self.one_shot = True + self._event.set() + self._event.clear() + + def reset(self) -> None: self.remaining = self.limit - self.outgoing - self.expires = None - self.reset_after = 0.0 - self.dirty = False + self.reset_at = 0.0 - def update(self, response: aiohttp.ClientResponse, *, use_clock: bool = False) -> None: - headers = response.headers - self.limit = int(headers.get('X-Ratelimit-Limit', 1)) + def update(self, response: aiohttp.ClientResponse) -> bool: - if self.dirty: - self.remaining = min(int(headers.get('X-Ratelimit-Remaining', 0)), self.limit - self.outgoing) + # Shared scope 429 has longer "reset_at", determined using the retry-after field + limit = int(response.headers['X-Ratelimit-Limit']) + if response.headers.get('X-RateLimit-Scope') == 'shared': + reset_at = self.http.loop.time() + float(response.headers['Retry-After']) + remaining = 0 else: - self.remaining = int(headers.get('X-Ratelimit-Remaining', 0)) - self.dirty = True - - reset_after = headers.get('X-Ratelimit-Reset-After') - if use_clock or not reset_after: - utc = datetime.timezone.utc - now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(headers['X-Ratelimit-Reset']), utc) - self.reset_after = (reset - now).total_seconds() + + # Consider a lower remaining value because updates can be out of order, so self.outgoing is used + reset_at = self.http.loop.time() + (float(response.headers['X-Ratelimit-Reset']) - time.time()) + remaining = 
min(int(response.headers['X-Ratelimit-Remaining']), limit - self.outgoing) + + # The checks below combats out-of-order responses and alternating sub ratelimits. + # As a result, there will be lower throughput for routes with subratelimits. + # Unless we document and maintain the undocumented and unstable subratelimits ourselves. + + # 1. Completely ignore if the ratelimit window has expired; that data is useless + if self.http.loop.time() >= reset_at: + return False + + # 2. Always use the longest reset_at value + update = False + if reset_at > self.reset_at: + self.reset_at = reset_at + update = True + + # 3. Always use the lowest remaining value + if remaining < self.remaining: + self.remaining = remaining + update = True + + self.limit = limit + self.one_shot = False + + # Return whether this update was relevant or not. + return update + + def is_inactive(self) -> bool: + delta = self.http.loop.time() - self._last_request + return delta >= 300 and (self.one_shot or (self.outgoing == 0 and self.pending == 0)) + + async def _wait_global(self, start_time: float): + + # Sleep up to 3 times, to account for global reset at overwriting during sleeps + for i in range(3): + seconds = self.http.global_reset_at - start_time + if seconds > 0: + await asyncio.sleep(seconds) + continue + break else: - self.reset_after = float(reset_after) - - self.expires = self._loop.time() + self.reset_after - - def _wake_next(self) -> None: - while self._pending_requests: - future = self._pending_requests.popleft() - if not future.done(): - future.set_result(None) - break - - def _wake(self, count: int = 1, *, exception: Optional[RateLimited] = None) -> None: - awaken = 0 - while self._pending_requests: - future = self._pending_requests.popleft() - if not future.done(): - if exception: - future.set_exception(exception) - else: - future.set_result(None) - awaken += 1 + raise ValueError("Global reset at changed more than 3 times") - if awaken >= count: - break + async def _wait(self): - 
async def _refresh(self) -> None: - error = self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout - exception = RateLimited(self.reset_after) if error else None - async with self._sleeping: - if not error: - await asyncio.sleep(self.reset_after) + # Consider waiting if none is remaining + if not self.remaining: - self.reset() - self._wake(self.remaining, exception=exception) + # If reset_at is not set yet, wait for the last request, if outgoing, to finish first + # for up to 3 seconds instead of using aiohttp's default 5 min timeout. + if not self.reset_at and (not self._last_request or self.http.loop.time() - self._last_request < 3): + try: + await asyncio.wait_for(self._event.wait(), 3) + except asyncio.TimeoutError: + fmt = 'Initial request for rate limit bucket (%s) never finished. Skipping.' + _log.warning(fmt, self.key) + pass + + # If none are still remaining then start sleeping + if not self.remaining and not self.one_shot: + + # Sleep up to 3 times, giving room for a bucket update and a bucket change + # or 2 sub-ratelimit bucket changes, prioritizing is handled in update() + for i in range(3): + seconds = self.reset_at - self.http.loop.time() + copy = self.reset_at + if seconds > 0: + if i: + fmt = 'A rate limit bucket (%s) has changed reset_at during sleep. Sleeping again for %.2f seconds.' + else: + fmt = 'A rate limit bucket (%s) has been exhausted. Sleeping for %.2f seconds.' 
+ _log.warning(fmt, self.key, seconds) + await asyncio.sleep(seconds) + if copy == self.reset_at: + self.reset() + elif not self.remaining and not self.one_shot: + continue # sleep again + break + else: + raise ValueError("Reset at changed more than 3 times") + + async def acquire(self): + start_time: float = self.http.loop.time() + + # Resources confirmed to be "one shots" after the first request like interaction response + # don't have ratelimits and can skip the entire queue logic except global wait + if self.one_shot: + await self._wait_global(start_time) + else: - def is_expired(self) -> bool: - return self.expires is not None and self._loop.time() > self.expires + # Ensure only 1 request goes through the inner acquire logic at a time + self.pending += 1 + async with self._lock: + await self._wait() + await self._wait_global(start_time) + self.remaining -= 1 # one shot changing this doesn't matter + self.pending -= 1 + self.outgoing += 1 - def is_inactive(self) -> bool: - delta = self._loop.time() - self._last_request - return delta >= 300 and self.outgoing == 0 and len(self._pending_requests) == 0 - - async def acquire(self) -> None: - self._last_request = self._loop.time() - if self.is_expired(): - self.reset() - - if self._max_ratelimit_timeout is not None and self.expires is not None: - # Check if we can pre-emptively block this request for having too large of a timeout - current_reset_after = self.expires - self._loop.time() - if current_reset_after > self._max_ratelimit_timeout: - raise RateLimited(current_reset_after) - - while self.remaining <= 0: - future = self._loop.create_future() - self._pending_requests.append(future) - try: - while not future.done(): - # 30 matches the smallest allowed max_ratelimit_timeout - max_wait_time = self.expires - self._loop.time() if self.expires else 30 - await asyncio.wait([future], timeout=max_wait_time) - if not future.done(): - await self._refresh() - except: - future.cancel() - if self.remaining > 0 and not 
future.cancelled(): - self._wake_next() - raise - - self.remaining -= 1 - self.outgoing += 1 + self._last_request = self.http.loop.time() + return self + + async def release(self): + if not self.one_shot: + self.outgoing -= 1 + self._event.set() + self._event.clear() async def __aenter__(self) -> Self: await self.acquire() return self async def __aexit__(self, type: Type[BE], value: BE, traceback: TracebackType) -> None: - self.outgoing -= 1 - tokens = self.remaining - self.outgoing - # Check whether the rate limit needs to be pre-emptively slept on - # Note that this is a Lock to prevent multiple rate limit objects from sleeping at once - if not self._sleeping.locked(): - if tokens <= 0: - await self._refresh() - elif self._pending_requests: - exception = ( - RateLimited(self.reset_after) - if self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout - else None - ) - self._wake(tokens, exception=exception) + await self.release() # For some reason, the Discord voice websocket expects this header to be @@ -535,7 +551,7 @@ class HTTPClient: # one shot requests that don't have a bucket hash # When this reaches 256 elements, it will try to evict based off of expiry self._buckets: Dict[str, Ratelimit] = {} - self._global_over: asyncio.Event = MISSING + self.global_reset_at: float = 0.0 self.token: Optional[str] = None self.proxy: Optional[str] = proxy self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth @@ -577,7 +593,7 @@ class HTTPClient: try: value = self._buckets[key] except KeyError: - self._buckets[key] = value = Ratelimit(self.max_ratelimit_timeout) + self._buckets[key] = value = Ratelimit(self, key) self._try_clear_expired_ratelimits() return value @@ -631,14 +647,11 @@ class HTTPClient: if self.proxy_auth is not None: kwargs['proxy_auth'] = self.proxy_auth - if not self._global_over.is_set(): - # wait until the global lock is complete - await self._global_over.wait() - response: Optional[aiohttp.ClientResponse] = None data: 
Optional[Union[Dict[str, Any], str]] = None - async with ratelimit: - for tries in range(5): + + for tries in range(5): + async with ratelimit: if files: for f in files: f.reset(seek=tries) @@ -654,116 +667,78 @@ class HTTPClient: async with self.__session.request(method, url, **kwargs) as response: _log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status) - # even errors have text involved in them so this is safe to call + # Endpoint has ratelimit headers + new_bucket_hash = response.headers.get('X-Ratelimit-Bucket') + if new_bucket_hash: + + # Ratelimit headers are up to date and relevant + if ratelimit.update(response): + + # Adjust key if the bucket has changed. Either encountered a sub-ratelimit + # or Discord just wants to change ratelimit values for an update probably. + if new_bucket_hash != bucket_hash: + if bucket_hash is not None: + fmt = 'A route (%s) has changed hashes: %s -> %s.' + _log.debug(fmt, route_key, bucket_hash, new_bucket_hash) + self._buckets.pop(key, None) + elif route_key not in self._bucket_hashes: + fmt = '%s has found its initial rate limit bucket hash (%s).' + _log.debug(fmt, route_key, new_bucket_hash) + self._bucket_hashes[route_key] = new_bucket_hash + ratelimit.key = new_bucket_hash + route.major_parameters + self._buckets[ratelimit.key] = ratelimit + + # Global rate limit 429 wont have ratelimit headers (also can't tell if it's one-shot) + elif response.headers.get('X-RateLimit-Global'): + retry_after: float = float(response.headers['Retry-After']) + _log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after) + self.global_reset_at = self.loop.time() + retry_after + + # Endpoint does not have ratelimit headers; it's one-shot. 
+ elif not ratelimit.one_shot: + ratelimit.no_headers() + + # Errors have text involved in them so this is safe to call data = await json_or_text(response) - # Update and use rate limit information if the bucket header is present - discord_hash = response.headers.get('X-Ratelimit-Bucket') - # I am unsure if X-Ratelimit-Bucket is always available - # However, X-Ratelimit-Remaining has been a consistent cornerstone that worked - has_ratelimit_headers = 'X-Ratelimit-Remaining' in response.headers - if discord_hash is not None: - # If the hash Discord has provided is somehow different from our current hash something changed - if bucket_hash != discord_hash: - if bucket_hash is not None: - # If the previous hash was an actual Discord hash then this means the - # hash has changed sporadically. - # This can be due to two reasons - # 1. It's a sub-ratelimit which is hard to handle - # 2. The rate limit information genuinely changed - # There is no good way to discern these, Discord doesn't provide a way to do so. - # At best, there will be some form of logging to help catch it. - # Alternating sub-ratelimits means that the requests oscillate between - # different underlying rate limits -- this can lead to unexpected 429s - # It is unavoidable. - fmt = 'A route (%s) has changed hashes: %s -> %s.' - _log.debug(fmt, route_key, bucket_hash, discord_hash) - - self._bucket_hashes[route_key] = discord_hash - self._buckets[f'{discord_hash}:{route.major_parameters}'] = ratelimit - self._buckets.pop(key, None) - elif route_key not in self._bucket_hashes: - fmt = '%s has found its initial rate limit bucket hash (%s).' - _log.debug(fmt, route_key, discord_hash) - self._bucket_hashes[route_key] = discord_hash - self._buckets[f'{discord_hash}:{route.major_parameters}'] = ratelimit - - if has_ratelimit_headers: - if response.status != 429: - ratelimit.update(response, use_clock=self.use_clock) - if ratelimit.remaining == 0: - _log.debug( - 'A rate limit bucket (%s) has been exhausted. 
Pre-emptively rate limiting...',
-                            discord_hash or route_key,
-                        )
-
-                        # the request was successful so just return the text/json
+                        # The request was successful so just return the text/json
                         if 300 > response.status >= 200:
                             _log.debug('%s %s has received %s', method, url, data)
                             return data
-                        # we are being rate limited
-                        if response.status == 429:
-                            if not response.headers.get('Via') or isinstance(data, str):
-                                # Banned by Cloudflare more than likely.
+                        # We are being ratelimited
+                        elif response.status == 429:
+                            retry_after: float = float(response.headers['Retry-After'])  # only in headers for cf ban
+
+                            # Hit Cloudflare ban for too many invalid requests (10,000 per 10 minutes)
+                            # An invalid HTTP request is 401, 403, or 429 (excluding "shared" scope).
+                            # This ban is the only one that is IP based, and not just token based.
+                            if not response.headers.get('Via'):
+                                fmt = 'We are Cloudflare banned for %.2f seconds.'
+                                _log.warning(fmt, retry_after)
                                 raise HTTPException(response, data)
-                            if ratelimit.remaining > 0:
-                                # According to night
-                                # https://github.com/discord/discord-api-docs/issues/2190#issuecomment-816363129
-                                # Remaining > 0 and 429 means that a sub ratelimit was hit.
-                                # It is unclear what should happen in these cases other than just using the retry_after
-                                # value in the body.
-                                _log.debug(
-                                    '%s %s received a 429 despite having %s remaining requests. This is a sub-ratelimit.',
-                                    method,
-                                    url,
-                                    ratelimit.remaining,
-                                )
-
-                            retry_after: float = data['retry_after']
-                            if self.max_ratelimit_timeout and retry_after > self.max_ratelimit_timeout:
-                                _log.warning(
-                                    'We are being rate limited. %s %s responded with 429. Timeout of %.2f was too long, erroring instead.',
-                                    method,
-                                    url,
-                                    retry_after,
-                                )
+                            elif self.max_ratelimit_timeout and retry_after > self.max_ratelimit_timeout:
+                                fmt = 'We are being rate limited. %s %s responded with 429. Timeout of %.2f was too long, erroring instead.'
+ _log.warning(fmt, method, url, retry_after) raise RateLimited(retry_after) + elif ratelimit.remaining > 0: + # According to night, remaining > 0 and 429 means that a sub ratelimit was hit. + # https://github.com/discord/discord-api-docs/issues/2190#issuecomment-816363129 + fmt = '%s %s received a 429 despite having %s remaining requests. This is a sub-ratelimit.' + _log.debug(fmt, method, url, ratelimit.remaining) + fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.' _log.warning(fmt, method, url, retry_after) - - _log.debug( - 'Rate limit is being handled by bucket hash %s with %r major parameters', - bucket_hash, - route.major_parameters, - ) - - # check if it's a global rate limit - is_global = data.get('global', False) - if is_global: - _log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after) - self._global_over.clear() - - await asyncio.sleep(retry_after) - _log.debug('Done sleeping for the rate limit. Retrying...') - - # release the global lock now that the - # global rate limit has passed - if is_global: - self._global_over.set() - _log.debug('Global rate limit is now over.') - - continue - - # we've received a 500, 502, 504, or 524, unconditional retry - if response.status in {500, 502, 504, 524}: + + # We've received a 500, 502, 504, or 524, unconditional retry + elif response.status in {500, 502, 504, 524}: await asyncio.sleep(1 + tries * 2) continue - # the usual error cases - if response.status == 403: + # The usual error cases + elif response.status == 403: raise Forbidden(response, data) elif response.status == 404: raise NotFound(response, data) @@ -774,19 +749,17 @@ class HTTPClient: # This is handling exceptions from the request except OSError as e: - # Connection reset by peer - if tries < 4 and e.errno in (54, 10054): - await asyncio.sleep(1 + tries * 2) - continue - raise - - if response is not None: - # We've run out of retries, raise. 
-            if response.status >= 500:
-                raise DiscordServerError(response, data)
-
-            raise HTTPException(response, data)
-
+                if tries == 4 or e.errno not in (54, 10054):
+                    raise  # not a retryable connection reset; re-raise the original OSError
+                retry_seconds: int = 1 + tries * 2
+                fmt = 'OS error for %s %s. Retrying in %d seconds.'
+                _log.warning(fmt, method, url, retry_seconds)
+                await asyncio.sleep(retry_seconds)
+
+        # We've run out of retries, raise
+        if response is not None:
+            raise DiscordServerError(response, data) if response.status >= 500 else HTTPException(response, data)
+        else: raise RuntimeError('Unreachable code in HTTP handling')

    async def get_from_cdn(self, url: str) -> bytes: