diff --git a/discord/http.py b/discord/http.py
index 3cfa359ad..7c8a3c758 100644
--- a/discord/http.py
+++ b/discord/http.py
@@ -46,7 +46,9 @@ from typing import (
     Union,
 )
 from urllib.parse import quote as _uriquote
+from collections import deque
 import weakref
+import datetime

 import aiohttp

@@ -283,9 +285,12 @@ def _set_api_version(value: int):
 class Route:
     BASE: ClassVar[str] = 'https://discord.com/api/v10'

-    def __init__(self, method: str, path: str, **parameters: Any) -> None:
+    def __init__(self, method: str, path: str, *, metadata: Optional[str] = None, **parameters: Any) -> None:
         self.path: str = path
         self.method: str = method
+        # Metadata is a special string used to differentiate between known sub rate limits
+        # Since these can't be handled generically, this is the next best way to do so.
+        self.metadata: Optional[str] = metadata
         url = self.BASE + self.path
         if parameters:
             url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()})
@@ -298,30 +303,137 @@ class Route:
         self.webhook_token: Optional[str] = parameters.get('webhook_token')

     @property
-    def bucket(self) -> str:
-        # the bucket is just method + path w/ major parameters
-        return f'{self.channel_id}:{self.guild_id}:{self.path}'
+    def key(self) -> str:
+        """The bucket key is used to represent the route in various mappings."""
+        if self.metadata:
+            return f'{self.method} {self.path}:{self.metadata}'
+        return f'{self.method} {self.path}'

+    @property
+    def major_parameters(self) -> str:
+        """Returns the major parameters formatted as a string.

-class MaybeUnlock:
-    def __init__(self, lock: asyncio.Lock) -> None:
-        self.lock: asyncio.Lock = lock
-        self._unlock: bool = True
+        This needs to be appended to a bucket hash to constitute a full rate limit key.
+        """
+        return '+'.join(
+            str(k) for k in (self.channel_id, self.guild_id, self.webhook_id, self.webhook_token) if k is not None
+        )

-    def __enter__(self) -> Self:
-        return self

-    def defer(self) -> None:
-        self._unlock = False
+class Ratelimit:
+    """Represents a Discord rate limit.
+
+    This is similar to a semaphore except tailored to Discord's rate limits. This is aware of
+    the expiry of a token window, along with the number of tokens available. The goal of this
+    design is to increase throughput of requests being sent concurrently rather than forcing
+    everything into a single lock queue per route.
+    """
+
+    def __init__(self) -> None:
+        self.limit: int = 1
+        self.remaining: int = self.limit
+        self.outgoing: int = 0
+        self.reset_after: float = 0.0
+        self.expires: Optional[float] = None
+        self.dirty: bool = False
+        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
+        self._pending_requests: deque[asyncio.Future[Any]] = deque()
+        # Only a single rate limit object should be sleeping at a time.
+        # The object that is sleeping is ultimately responsible for freeing the semaphore
+        # for the requests currently pending.
+        self._sleeping: asyncio.Lock = asyncio.Lock()
+
+    def __repr__(self) -> str:
+        return (
+            f'<Ratelimit limit={self.limit} remaining={self.remaining} pending_requests={len(self._pending_requests)} reset_after={self.reset_after}>'
+        )

-    def __exit__(
-        self,
-        exc_type: Optional[Type[BE]],
-        exc: Optional[BE],
-        traceback: Optional[TracebackType],
-    ) -> None:
-        if self._unlock:
-            self.lock.release()
+    def reset(self):
+        self.remaining = self.limit - self.outgoing
+        self.expires = None
+        self.reset_after = 0.0
+        self.dirty = False
+
+    def update(self, response: aiohttp.ClientResponse, *, use_clock: bool = False) -> None:
+        headers = response.headers
+        self.limit = int(headers.get('X-Ratelimit-Limit', 1))
+
+        if self.dirty:
+            self.remaining = min(int(headers.get('X-Ratelimit-Remaining', 0)), self.limit - self.outgoing)
+        else:
+            self.remaining = int(headers.get('X-Ratelimit-Remaining', 0))
+            self.dirty = True
+
+        reset_after = headers.get('X-Ratelimit-Reset-After')
+        if use_clock or not reset_after:
+            utc = datetime.timezone.utc
+            now = datetime.datetime.now(utc)
+            reset = datetime.datetime.fromtimestamp(float(headers['X-Ratelimit-Reset']), utc)
+            self.reset_after = (reset - now).total_seconds()
+        else:
+            self.reset_after = float(reset_after)
+
+        self.expires = self._loop.time() + self.reset_after
+
+    def _wake_next(self) -> None:
+        while self._pending_requests:
+            future = self._pending_requests.popleft()
+            if not future.done():
+                future.set_result(None)
+                break
+
+    def _wake(self, count: int = 1) -> None:
+        awaken = 0
+        while self._pending_requests:
+            future = self._pending_requests.popleft()
+            if not future.done():
+                future.set_result(None)
+                self._has_just_awaken = True
+                awaken += 1
+
+            if awaken >= count:
+                break
+
+    async def _refresh(self) -> None:
+        async with self._sleeping:
+            await asyncio.sleep(self.reset_after)
+
+        self.reset()
+        self._wake(self.remaining)
+
+    def is_expired(self) -> bool:
+        return self.expires is not None and self._loop.time() > self.expires
+
+    async def acquire(self) -> None:
+        if self.is_expired():
+            self.reset()
+
+        while self.remaining <= 0:
+            future = self._loop.create_future()
+            self._pending_requests.append(future)
+            try:
+                await future
+            except:
+                future.cancel()
+                if self.remaining > 0 and not future.cancelled():
+                    self._wake_next()
+                raise
+
+        self.remaining -= 1
+        self.outgoing += 1
+
+    async def __aenter__(self) -> Self:
+        await self.acquire()
+        return self
+
+    async def __aexit__(self, type: Type[BE], value: BE, traceback: TracebackType) -> None:
+        self.outgoing -= 1
+        tokens = self.remaining - self.outgoing
+        # Check whether the rate limit needs to be pre-emptively slept on
+        # Note that this is a Lock to prevent multiple rate limit objects from sleeping at once
+        if not self._sleeping.locked():
+            if tokens <= 0:
+                await self._refresh()
+            elif self._pending_requests:
+                self._wake(tokens)


 # For some reason, the Discord voice websocket expects this header to be
@@ -345,7 +457,15 @@ class HTTPClient:
         self.loop: asyncio.AbstractEventLoop = loop
         self.connector: aiohttp.BaseConnector = connector or MISSING
         self.__session: aiohttp.ClientSession = MISSING  # filled in static_login
-        self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+        # Route key -> Bucket hash
+        self._bucket_hashes: Dict[str, str] = {}
+        # Bucket Hash + Major Parameters -> Rate limit
+        self._buckets: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary()
+        # Route key + Major Parameters -> Rate limit
+        # Used for temporary one shot requests that don't have a bucket hash
+        # While I'd love to use a single mapping for these, doing this would cause the rate limit objects
+        # to inexplicably be evicted from the dictionary before we're done with it
+        self._oneshots: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary()
         self._global_over: asyncio.Event = MISSING
         self.token: Optional[str] = None
         self.proxy: Optional[str] = proxy
@@ -383,15 +503,24 @@ class HTTPClient:
         form: Optional[Iterable[Dict[str, Any]]] = None,
         **kwargs: Any,
     ) -> Any:
-        bucket = route.bucket
         method = route.method
         url = route.url
+        route_key = route.key
+
+        bucket_hash = None
+        try:
+            bucket_hash = self._bucket_hashes[route_key]
+        except KeyError:
+            key = route_key + route.major_parameters
+            mapping = self._oneshots
+        else:
+            key = bucket_hash + route.major_parameters
+            mapping = self._buckets

-        lock = self._locks.get(bucket)
-        if lock is None:
-            lock = asyncio.Lock()
-            if bucket is not None:
-                self._locks[bucket] = lock
+        try:
+            ratelimit = mapping[key]
+        except KeyError:
+            mapping[key] = ratelimit = Ratelimit()

         # header creation
         headers: Dict[str, str] = {
@@ -427,8 +556,7 @@ class HTTPClient:

         response: Optional[aiohttp.ClientResponse] = None
         data: Optional[Union[Dict[str, Any], str]] = None
-        await lock.acquire()
-        with MaybeUnlock(lock) as maybe_lock:
+        async with ratelimit:
             for tries in range(5):
                 if files:
                     for f in files:
@@ -448,14 +576,46 @@ class HTTPClient:

                        # even errors have text involved in them so this is safe to call
                        data = await json_or_text(response)

-                        # check if we have rate limit header information
-                        remaining = response.headers.get('X-Ratelimit-Remaining')
-                        if remaining == '0' and response.status != 429:
-                            # we've depleted our current bucket
-                            delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock)
-                            _log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
-                            maybe_lock.defer()
-                            self.loop.call_later(delta, lock.release)
+                        # Update and use rate limit information if the bucket header is present
+                        discord_hash = response.headers.get('X-Ratelimit-Bucket')
+                        # I am unsure if X-Ratelimit-Bucket is always available
+                        # However, X-Ratelimit-Remaining has been a consistent cornerstone that worked
+                        has_ratelimit_headers = 'X-Ratelimit-Remaining' in response.headers
+                        if discord_hash is not None:
+                            # If the hash Discord has provided is somehow different from our current hash something changed
+                            if bucket_hash != discord_hash:
+                                if bucket_hash is not None:
+                                    # If the previous hash was an actual Discord hash then this means the
+                                    # hash has changed sporadically.
+                                    # This can be due to two reasons
+                                    # 1. It's a sub-ratelimit which is hard to handle
+                                    # 2. The rate limit information genuinely changed
+                                    # There is no good way to discern these, Discord doesn't provide a way to do so.
+                                    # At best, there will be some form of logging to help catch it.
+                                    # Alternating sub-ratelimits means that the requests oscillate between
+                                    # different underlying rate limits -- this can lead to unexpected 429s
+                                    # It is unavoidable.
+                                    fmt = 'A route (%s) has changed hashes: %s -> %s.'
+                                    _log.debug(fmt, route_key, bucket_hash, discord_hash)
+
+                                    self._bucket_hashes[route_key] = discord_hash
+                                    recalculated_key = discord_hash + route.major_parameters
+                                    self._buckets[recalculated_key] = ratelimit
+                                    self._buckets.pop(key, None)
+                                elif route_key not in self._bucket_hashes:
+                                    fmt = '%s has found its initial rate limit bucket hash (%s).'
+                                    _log.debug(fmt, route_key, discord_hash)
+                                    self._bucket_hashes[route_key] = discord_hash
+                                    self._buckets[discord_hash + route.major_parameters] = ratelimit
+
+                        if has_ratelimit_headers:
+                            if response.status != 429:
+                                ratelimit.update(response, use_clock=self.use_clock)
+                                if ratelimit.remaining == 0:
+                                    _log.debug(
+                                        'A rate limit bucket (%s) has been exhausted. Pre-emptively rate limiting...',
+                                        discord_hash or route_key,
+                                    )

                        # the request was successful so just return the text/json
                        if 300 > response.status >= 200:
@@ -468,11 +628,17 @@ class HTTPClient:
                                # Banned by Cloudflare more than likely.
                                raise HTTPException(response, data)

-                            fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"'
+                            fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.'

                            # sleep a bit
                            retry_after: float = data['retry_after']
-                            _log.warning(fmt, retry_after, bucket, stack_info=True)
+                            _log.warning(fmt, method, url, retry_after, stack_info=True)
+
+                            _log.debug(
+                                'Rate limit is being handled by bucket hash %s with %r major parameters',
+                                bucket_hash,
+                                route.major_parameters,
+                            )

                            # check if it's a global rate limit
                            is_global = data.get('global', False)
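Not part of the patch: the sketch below is a minimal illustration of the two-level bookkeeping the comments above describe, where a route key first resolves to the bucket hash Discord reports in X-Ratelimit-Bucket, and that hash plus the major parameters then resolves to a shared rate limit object, with a one-shot mapping standing in until the hash is known. The names Bucket, RateLimitStore, resolve, and learn_hash, as well as the example route key and the hash 'abc123', are illustrative stand-ins rather than identifiers from discord/http.py.

```python
import weakref
from typing import Dict, Optional, Tuple


class Bucket:
    """Stand-in for the Ratelimit object stored in the weak mappings."""


class RateLimitStore:
    def __init__(self) -> None:
        # Route key (method + path, plus optional metadata) -> bucket hash from X-Ratelimit-Bucket
        self.bucket_hashes: Dict[str, str] = {}
        # Bucket hash + major parameters -> shared rate limit state
        self.buckets: 'weakref.WeakValueDictionary[str, Bucket]' = weakref.WeakValueDictionary()
        # Route key + major parameters -> temporary state used until a bucket hash is learned
        self.oneshots: 'weakref.WeakValueDictionary[str, Bucket]' = weakref.WeakValueDictionary()

    def resolve(self, route_key: str, major_parameters: str) -> Tuple[str, Bucket]:
        """Pick the mapping and key for a request, creating state on first use."""
        bucket_hash: Optional[str] = self.bucket_hashes.get(route_key)
        if bucket_hash is None:
            key, mapping = route_key + major_parameters, self.oneshots
        else:
            key, mapping = bucket_hash + major_parameters, self.buckets
        try:
            return key, mapping[key]
        except KeyError:
            mapping[key] = bucket = Bucket()
            return key, bucket

    def learn_hash(self, route_key: str, major_parameters: str, discord_hash: str, bucket: Bucket) -> None:
        """Record the hash revealed by a response and re-key the shared state.

        The stale one-shot entry is not removed explicitly; because the mapping
        holds weak references, it disappears once no request is using it.
        """
        self.bucket_hashes[route_key] = discord_hash
        self.buckets[discord_hash + major_parameters] = bucket


# Example: once the hash is learned, later requests to the same route and major
# parameters resolve to the same shared state under the recalculated key.
store = RateLimitStore()
key_a, bucket_a = store.resolve('GET /channels/{channel_id}', '123')
store.learn_hash('GET /channels/{channel_id}', '123', 'abc123', bucket_a)
key_b, bucket_b = store.resolve('GET /channels/{channel_id}', '123')
assert bucket_b is bucket_a and key_b == 'abc123123'
```

In the patch itself each resolved object is a Ratelimit rather than a plain placeholder: requests enter it with async with, block once the window's remaining tokens are spent, and are woken again by whichever request holds the _sleeping lock after the reset-after delay.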