From c17eb31328f95b092bebdbc7e68539ea33c3ab9c Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 18 Jul 2022 08:55:05 -0400 Subject: [PATCH] Rewrite rate limit handling to use X-Ratelimit-Bucket and a semaphore This should increase throughput of the number of requests that can be made at once, while simultaneously following the new standard practice of using the rate limit bucket header. This is an accumulation of a lot of months of work between a few people and it has been tested extensively. From the testing it seems to work fine, but I'm not sure if it's the best way to do it. This changeset does not currently take into consideration sub rate limits yet, but the foundation is there via Route.metadata. In the future, this metadata will be filled in with the known sub rate limit implementation to allow them to have separate keys in the rate limit mapping. Co-authored-by: Josh --- discord/http.py | 244 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 205 insertions(+), 39 deletions(-) diff --git a/discord/http.py b/discord/http.py index 3cfa359ad..7c8a3c758 100644 --- a/discord/http.py +++ b/discord/http.py @@ -46,7 +46,9 @@ from typing import ( Union, ) from urllib.parse import quote as _uriquote +from collections import deque import weakref +import datetime import aiohttp @@ -283,9 +285,12 @@ def _set_api_version(value: int): class Route: BASE: ClassVar[str] = 'https://discord.com/api/v10' - def __init__(self, method: str, path: str, **parameters: Any) -> None: + def __init__(self, method: str, path: str, *, metadata: Optional[str] = None, **parameters: Any) -> None: self.path: str = path self.method: str = method + # Metadata is a special string used to differentiate between known sub rate limits + # Since these can't be handled generically, this is the next best way to do so. + self.metadata: Optional[str] = metadata url = self.BASE + self.path if parameters: url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()}) @@ -298,30 +303,137 @@ class Route: self.webhook_token: Optional[str] = parameters.get('webhook_token') @property - def bucket(self) -> str: - # the bucket is just method + path w/ major parameters - return f'{self.channel_id}:{self.guild_id}:{self.path}' + def key(self) -> str: + """The bucket key is used to represent the route in various mappings.""" + if self.metadata: + return f'{self.method} {self.path}:{self.metadata}' + return f'{self.method} {self.path}' + @property + def major_parameters(self) -> str: + """Returns the major parameters formatted a string. -class MaybeUnlock: - def __init__(self, lock: asyncio.Lock) -> None: - self.lock: asyncio.Lock = lock - self._unlock: bool = True + This needs to be appended to a bucket hash to constitute as a full rate limit key. + """ + return '+'.join( + str(k) for k in (self.channel_id, self.guild_id, self.webhook_id, self.webhook_token) if k is not None + ) - def __enter__(self) -> Self: - return self - def defer(self) -> None: - self._unlock = False +class Ratelimit: + """Represents a Discord rate limit. + + This is similar to a semaphore except tailored to Discord's rate limits. This is aware of + the expiry of a token window, along with the number of tokens available. The goal of this + design is to increase throughput of requests being sent concurrently rather than forcing + everything into a single lock queue per route. + """ + def __init__(self) -> None: + self.limit: int = 1 + self.remaining: int = self.limit + self.outgoing: int = 0 + self.reset_after: float = 0.0 + self.expires: Optional[float] = None + self.dirty: bool = False + self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + self._pending_requests: deque[asyncio.Future[Any]] = deque() + # Only a single rate limit object should be sleeping at a time. + # The object that is sleeping is ultimately responsible for freeing the semaphore + # for the requests currently pending. + self._sleeping: asyncio.Lock = asyncio.Lock() + + def __repr__(self) -> str: + return ( + f'' + ) - def __exit__( - self, - exc_type: Optional[Type[BE]], - exc: Optional[BE], - traceback: Optional[TracebackType], - ) -> None: - if self._unlock: - self.lock.release() + def reset(self): + self.remaining = self.limit - self.outgoing + self.expires = None + self.reset_after = 0.0 + self.dirty = False + + def update(self, response: aiohttp.ClientResponse, *, use_clock: bool = False) -> None: + headers = response.headers + self.limit = int(headers.get('X-Ratelimit-Limit', 1)) + + if self.dirty: + self.remaining = min(int(headers.get('X-Ratelimit-Remaining', 0)), self.limit - self.outgoing) + else: + self.remaining = int(headers.get('X-Ratelimit-Remaining', 0)) + self.dirty = True + + reset_after = headers.get('X-Ratelimit-Reset-After') + if use_clock or not reset_after: + utc = datetime.timezone.utc + now = datetime.datetime.now(utc) + reset = datetime.datetime.fromtimestamp(float(headers['X-Ratelimit-Reset']), utc) + self.reset_after = (reset - now).total_seconds() + else: + self.reset_after = float(reset_after) + + self.expires = self._loop.time() + self.reset_after + + def _wake_next(self) -> None: + while self._pending_requests: + future = self._pending_requests.popleft() + if not future.done(): + future.set_result(None) + break + + def _wake(self, count: int = 1) -> None: + awaken = 0 + while self._pending_requests: + future = self._pending_requests.popleft() + if not future.done(): + future.set_result(None) + self._has_just_awaken = True + awaken += 1 + + if awaken >= count: + break + + async def _refresh(self) -> None: + async with self._sleeping: + await asyncio.sleep(self.reset_after) + self.reset() + self._wake(self.remaining) + + def is_expired(self) -> bool: + return self.expires is not None and self._loop.time() > self.expires + + async def acquire(self) -> None: + if self.is_expired(): + self.reset() + + while self.remaining <= 0: + future = self._loop.create_future() + self._pending_requests.append(future) + try: + await future + except: + future.cancel() + if self.remaining > 0 and not future.cancelled(): + self._wake_next() + raise + + self.remaining -= 1 + self.outgoing += 1 + + async def __aenter__(self) -> Self: + await self.acquire() + return self + + async def __aexit__(self, type: Type[BE], value: BE, traceback: TracebackType) -> None: + self.outgoing -= 1 + tokens = self.remaining - self.outgoing + # Check whether the rate limit needs to be pre-emptively slept on + # Note that this is a Lock to prevent multiple rate limit objects from sleeping at once + if not self._sleeping.locked(): + if tokens <= 0: + await self._refresh() + elif self._pending_requests: + self._wake(tokens) # For some reason, the Discord voice websocket expects this header to be @@ -345,7 +457,15 @@ class HTTPClient: self.loop: asyncio.AbstractEventLoop = loop self.connector: aiohttp.BaseConnector = connector or MISSING self.__session: aiohttp.ClientSession = MISSING # filled in static_login - self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + # Route key -> Bucket hash + self._bucket_hashes: Dict[str, str] = {} + # Bucket Hash + Major Parameters -> Rate limit + self._buckets: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary() + # Route key + Major Parameters -> Rate limit + # Used for temporary one shot requests that don't have a bucket hash + # While I'd love to use a single mapping for these, doing this would cause the rate limit objects + # to inexplicably be evicted from the dictionary before we're done with it + self._oneshots: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary() self._global_over: asyncio.Event = MISSING self.token: Optional[str] = None self.proxy: Optional[str] = proxy @@ -383,15 +503,24 @@ class HTTPClient: form: Optional[Iterable[Dict[str, Any]]] = None, **kwargs: Any, ) -> Any: - bucket = route.bucket method = route.method url = route.url + route_key = route.key + + bucket_hash = None + try: + bucket_hash = self._bucket_hashes[route_key] + except KeyError: + key = route_key + route.major_parameters + mapping = self._oneshots + else: + key = bucket_hash + route.major_parameters + mapping = self._buckets - lock = self._locks.get(bucket) - if lock is None: - lock = asyncio.Lock() - if bucket is not None: - self._locks[bucket] = lock + try: + ratelimit = mapping[key] + except KeyError: + mapping[key] = ratelimit = Ratelimit() # header creation headers: Dict[str, str] = { @@ -427,8 +556,7 @@ class HTTPClient: response: Optional[aiohttp.ClientResponse] = None data: Optional[Union[Dict[str, Any], str]] = None - await lock.acquire() - with MaybeUnlock(lock) as maybe_lock: + async with ratelimit: for tries in range(5): if files: for f in files: @@ -448,14 +576,46 @@ class HTTPClient: # even errors have text involved in them so this is safe to call data = await json_or_text(response) - # check if we have rate limit header information - remaining = response.headers.get('X-Ratelimit-Remaining') - if remaining == '0' and response.status != 429: - # we've depleted our current bucket - delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock) - _log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta) - maybe_lock.defer() - self.loop.call_later(delta, lock.release) + # Update and use rate limit information if the bucket header is present + discord_hash = response.headers.get('X-Ratelimit-Bucket') + # I am unsure if X-Ratelimit-Bucket is always available + # However, X-Ratelimit-Remaining has been a consistent cornerstone that worked + has_ratelimit_headers = 'X-Ratelimit-Remaining' in response.headers + if discord_hash is not None: + # If the hash Discord has provided is somehow different from our current hash something changed + if bucket_hash != discord_hash: + if bucket_hash is not None: + # If the previous hash was an actual Discord hash then this means the + # hash has changed sporadically. + # This can be due to two reasons + # 1. It's a sub-ratelimit which is hard to handle + # 2. The rate limit information genuinely changed + # There is no good way to discern these, Discord doesn't provide a way to do so. + # At best, there will be some form of logging to help catch it. + # Alternating sub-ratelimits means that the requests oscillate between + # different underlying rate limits -- this can lead to unexpected 429s + # It is unavoidable. + fmt = 'A route (%s) has changed hashes: %s -> %s.' + _log.debug(fmt, route_key, bucket_hash, discord_hash) + + self._bucket_hashes[route_key] = discord_hash + recalculated_key = discord_hash + route.major_parameters + self._buckets[recalculated_key] = ratelimit + self._buckets.pop(key, None) + elif route_key not in self._bucket_hashes: + fmt = '%s has found its initial rate limit bucket hash (%s).' + _log.debug(fmt, route_key, discord_hash) + self._bucket_hashes[route_key] = discord_hash + self._buckets[discord_hash + route.major_parameters] = ratelimit + + if has_ratelimit_headers: + if response.status != 429: + ratelimit.update(response, use_clock=self.use_clock) + if ratelimit.remaining == 0: + _log.debug( + 'A rate limit bucket (%s) has been exhausted. Pre-emptively rate limiting...', + discord_hash or route_key, + ) # the request was successful so just return the text/json if 300 > response.status >= 200: @@ -468,11 +628,17 @@ class HTTPClient: # Banned by Cloudflare more than likely. raise HTTPException(response, data) - fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' + fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.' # sleep a bit retry_after: float = data['retry_after'] - _log.warning(fmt, retry_after, bucket, stack_info=True) + _log.warning(fmt, method, url, retry_after, stack_info=True) + + _log.debug( + 'Rate limit is being handled by bucket hash %s with %r major parameters', + bucket_hash, + route.major_parameters, + ) # check if it's a global rate limit is_global = data.get('global', False)