Browse Source

Rewrite ratelimit class

pull/10287/head
imp 2 days ago
committed by GitHub
parent
commit
c0181e9448
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 445
      discord/http.py

445
discord/http.py

@ -343,165 +343,179 @@ class Route:
str(k) for k in (self.channel_id, self.guild_id, self.webhook_id, self.webhook_token) if k is not None str(k) for k in (self.channel_id, self.guild_id, self.webhook_id, self.webhook_token) if k is not None
) )
class Ratelimit: class Ratelimit:
"""Represents a Discord rate limit. """Represents a Discord rate limit.
This is similar to a semaphore except tailored to Discord's rate limits. This is aware of This is similar to a semaphore except tailored to Discord's rate limits. This is aware of
the expiry of a token window, along with the number of tokens available. The goal of this the expiry of a token window, along with the number of tokens available.
design is to increase throughput of requests being sent concurrently rather than forcing
everything into a single lock queue per route.
""" """
__slots__ = ( __slots__ = (
'key',
'limit', 'limit',
'remaining', 'remaining',
'reset_at',
'pending',
'outgoing', 'outgoing',
'reset_after', 'one_shot',
'expires', 'http',
'dirty',
'_last_request', '_last_request',
'_max_ratelimit_timeout', '_lock',
'_loop', '_event',
'_pending_requests',
'_sleeping',
) )
def __init__(self, max_ratelimit_timeout: Optional[float]) -> None: def __init__(self, http: HTTPClient, key: str) -> None:
self.key = key
self.limit: int = 1 self.limit: int = 1
self.remaining: int = self.limit self.remaining: int = self.limit
self.reset_at: float = 0.0
self.pending: int = 0
self.outgoing: int = 0 self.outgoing: int = 0
self.reset_after: float = 0.0 self.one_shot: bool = False
self.expires: Optional[float] = None self.http: HTTPClient = http
self.dirty: bool = False self._last_request: float = 0.0
self._max_ratelimit_timeout: Optional[float] = max_ratelimit_timeout self._lock: asyncio.Lock = asyncio.Lock()
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() self._event: asyncio.Event = asyncio.Event()
self._pending_requests: deque[asyncio.Future[Any]] = deque()
# Only a single rate limit object should be sleeping at a time.
# The object that is sleeping is ultimately responsible for freeing the semaphore
# for the requests currently pending.
self._sleeping: asyncio.Lock = asyncio.Lock()
self._last_request: float = self._loop.time()
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<RateLimitBucket limit={self.limit} remaining={self.remaining} pending_requests={len(self._pending_requests)}>' f'<RateLimitBucket limit={self.limit} remaining={self.remaining} pending={self.pending}>'
) )
def reset(self): def no_headers(self) -> None:
self.one_shot = True
self._event.set()
self._event.clear()
def reset(self) -> None:
self.remaining = self.limit - self.outgoing self.remaining = self.limit - self.outgoing
self.expires = None self.reset_at = 0.0
self.reset_after = 0.0
self.dirty = False
def update(self, response: aiohttp.ClientResponse, *, use_clock: bool = False) -> None: def update(self, response: aiohttp.ClientResponse) -> None:
headers = response.headers
self.limit = int(headers.get('X-Ratelimit-Limit', 1))
if self.dirty: # Shared scope 429 has longer "reset_at", determined using the retry-after field
self.remaining = min(int(headers.get('X-Ratelimit-Remaining', 0)), self.limit - self.outgoing) limit = int(response.headers['X-Ratelimit-Limit'])
if response.headers.get('X-RateLimit-Scope') == 'shared':
reset_at = self.http.loop.time() + float(response.headers['Retry-After'])
remaining = 0
else: else:
self.remaining = int(headers.get('X-Ratelimit-Remaining', 0))
self.dirty = True # Consider a lower remaining value because updates can be out of order, so self.outgoing is used
reset_at = self.http.loop.time() + (float(response.headers['X-Ratelimit-Reset']) - time.time())
reset_after = headers.get('X-Ratelimit-Reset-After') remaining = min(int(response.headers['X-Ratelimit-Remaining']), limit - self.outgoing)
if use_clock or not reset_after:
utc = datetime.timezone.utc # The checks below combats out-of-order responses and alternating sub ratelimits.
now = datetime.datetime.now(utc) # As a result, there will be lower throughput for routes with subratelimits.
reset = datetime.datetime.fromtimestamp(float(headers['X-Ratelimit-Reset']), utc) # Unless we document and maintain the undocumented and unstable subratelimits ourselves.
self.reset_after = (reset - now).total_seconds()
# 1. Completely ignore if the ratelimit window has expired; that data is useless
if self.http.loop.time() >= reset_at:
return
# 2. Always use the longest reset_at value
update = False
if reset_at > self.reset_at:
self.reset_at = reset_at
update = True
# 3. Always use the lowest remaining value
if remaining < self.remaining:
self.remaining = remaining
update = True
self.limit = limit
self.one_shot = False
# Return whether this update was relevant or not.
return update
def is_inactive(self) -> bool:
delta = self.http.loop.time() - self._last_request
return delta >= 300 and (self.one_shot or (self.outgoing == 0 and self.pending == 0))
async def _wait_global(self, start_time: float):
# Sleep up to 3 times, to account for global reset at overwriting during sleeps
for i in range(3):
seconds = self.http.global_reset_at - start_time
if seconds > 0:
await asyncio.sleep(seconds)
continue
break
else: else:
self.reset_after = float(reset_after) raise ValueError("Global reset at changed more than 3 times")
self.expires = self._loop.time() + self.reset_after
def _wake_next(self) -> None:
while self._pending_requests:
future = self._pending_requests.popleft()
if not future.done():
future.set_result(None)
break
def _wake(self, count: int = 1, *, exception: Optional[RateLimited] = None) -> None:
awaken = 0
while self._pending_requests:
future = self._pending_requests.popleft()
if not future.done():
if exception:
future.set_exception(exception)
else:
future.set_result(None)
awaken += 1
if awaken >= count: async def _wait(self):
break
async def _refresh(self) -> None: # Consider waiting if none is remaining
error = self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout if not self.remaining:
exception = RateLimited(self.reset_after) if error else None
async with self._sleeping: # If reset_at is not set yet, wait for the last request, if outgoing, to finish first
if not error: # for up to 3 seconds instead of using aiohttp's default 5 min timeout.
await asyncio.sleep(self.reset_after) if not self.reset_at and (not self._last_request or self.http.loop.time() - self._last_request < 3):
try:
await asyncio.wait_for(self._event.wait(), 3)
except asyncio.TimeoutError:
fmt = 'Initial request for rate limit bucket (%s) never finished. Skipping.'
_log.warning(fmt, self.key)
pass
# If none are still remaining then start sleeping
if not self.remaining and not self.one_shot:
# Sleep up to 3 times, giving room for a bucket update and a bucket change
# or 2 sub-ratelimit bucket changes, prioritizing is handled in update()
for i in range(3):
seconds = self.reset_at - self.http.loop.time()
copy = self.reset_at
if seconds > 0:
if i:
fmt = 'A rate limit bucket (%s) has changed reset_at during sleep. Sleeping again for %.2f seconds.'
else:
fmt = 'A rate limit bucket (%s) has been exhausted. Sleeping for %.2f seconds.'
_log.warning(fmt, self.key, seconds)
await asyncio.sleep(seconds)
if copy == self.reset_at:
self.reset()
elif not self.remaining and not self.one_shot:
continue # sleep again
break
else:
raise ValueError("Reset at changed more than 3 times")
async def acquire(self):
start_time: float = self.http.loop.time()
# Resources confirmed to be "one shots" after the first request like interaction response
# don't have ratelimits and can skip the entire queue logic except global wait
if self.one_shot:
await self._wait_global(start_time)
else:
self.reset() # Ensure only 1 request goes through the inner acquire logic at a time
self._wake(self.remaining, exception=exception) self.pending += 1
async with self._lock:
await self._wait_global(start_time)
await self._wait()
self.remaining -= 1 # one shot changing this doesn't matter
self.pending -= 1
self.outgoing += 1
def is_expired(self) -> bool: self._last_request = self.http.loop.time()
return self.expires is not None and self._loop.time() > self.expires return self
def is_inactive(self) -> bool: async def release(self):
delta = self._loop.time() - self._last_request if not self.one_shot:
return delta >= 300 and self.outgoing == 0 and len(self._pending_requests) == 0 self.outgoing -= 1
self._event.set()
async def acquire(self) -> None: self._event.clear()
self._last_request = self._loop.time()
if self.is_expired():
self.reset()
if self._max_ratelimit_timeout is not None and self.expires is not None:
# Check if we can pre-emptively block this request for having too large of a timeout
current_reset_after = self.expires - self._loop.time()
if current_reset_after > self._max_ratelimit_timeout:
raise RateLimited(current_reset_after)
while self.remaining <= 0:
future = self._loop.create_future()
self._pending_requests.append(future)
try:
while not future.done():
# 30 matches the smallest allowed max_ratelimit_timeout
max_wait_time = self.expires - self._loop.time() if self.expires else 30
await asyncio.wait([future], timeout=max_wait_time)
if not future.done():
await self._refresh()
except:
future.cancel()
if self.remaining > 0 and not future.cancelled():
self._wake_next()
raise
self.remaining -= 1
self.outgoing += 1
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
await self.acquire() await self.acquire()
return self return self
async def __aexit__(self, type: Type[BE], value: BE, traceback: TracebackType) -> None: async def __aexit__(self, type: Type[BE], value: BE, traceback: TracebackType) -> None:
self.outgoing -= 1 await self.release()
tokens = self.remaining - self.outgoing
# Check whether the rate limit needs to be pre-emptively slept on
# Note that this is a Lock to prevent multiple rate limit objects from sleeping at once
if not self._sleeping.locked():
if tokens <= 0:
await self._refresh()
elif self._pending_requests:
exception = (
RateLimited(self.reset_after)
if self._max_ratelimit_timeout and self.reset_after > self._max_ratelimit_timeout
else None
)
self._wake(tokens, exception=exception)
# For some reason, the Discord voice websocket expects this header to be # For some reason, the Discord voice websocket expects this header to be
@ -535,7 +549,7 @@ class HTTPClient:
# one shot requests that don't have a bucket hash # one shot requests that don't have a bucket hash
# When this reaches 256 elements, it will try to evict based off of expiry # When this reaches 256 elements, it will try to evict based off of expiry
self._buckets: Dict[str, Ratelimit] = {} self._buckets: Dict[str, Ratelimit] = {}
self._global_over: asyncio.Event = MISSING self.global_reset_at: float = 0.0
self.token: Optional[str] = None self.token: Optional[str] = None
self.proxy: Optional[str] = proxy self.proxy: Optional[str] = proxy
self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth
@ -577,7 +591,7 @@ class HTTPClient:
try: try:
value = self._buckets[key] value = self._buckets[key]
except KeyError: except KeyError:
self._buckets[key] = value = Ratelimit(self.max_ratelimit_timeout) self._buckets[key] = value = Ratelimit(self, key)
self._try_clear_expired_ratelimits() self._try_clear_expired_ratelimits()
return value return value
@ -631,14 +645,11 @@ class HTTPClient:
if self.proxy_auth is not None: if self.proxy_auth is not None:
kwargs['proxy_auth'] = self.proxy_auth kwargs['proxy_auth'] = self.proxy_auth
if not self._global_over.is_set():
# wait until the global lock is complete
await self._global_over.wait()
response: Optional[aiohttp.ClientResponse] = None response: Optional[aiohttp.ClientResponse] = None
data: Optional[Union[Dict[str, Any], str]] = None data: Optional[Union[Dict[str, Any], str]] = None
async with ratelimit:
for tries in range(5): for tries in range(5):
async with ratelimit:
if files: if files:
for f in files: for f in files:
f.reset(seek=tries) f.reset(seek=tries)
@ -654,116 +665,76 @@ class HTTPClient:
async with self.__session.request(method, url, **kwargs) as response: async with self.__session.request(method, url, **kwargs) as response:
_log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status) _log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status)
# even errors have text involved in them so this is safe to call # Endpoint has ratelimit headers
new_bucket_hash = response.headers.get('X-Ratelimit-Bucket')
if new_bucket_hash:
# Ratelimit headers are up to date and relevant
if ratelimit.update(response):
# Adjust key if the bucket has changed. Either encountered a sub-ratelimit
# or Discord just wants to change ratelimit values for an update probably.
if new_bucket_hash != bucket_hash:
if bucket_hash is not None:
fmt = 'A route (%s) has changed hashes: %s -> %s.'
_log.debug(fmt, route_key, bucket_hash, new_bucket_hash)
self._buckets.pop(key, None)
elif route_key not in self._bucket_hashes:
fmt = '%s has found its initial rate limit bucket hash (%s).'
_log.debug(fmt, route_key, new_bucket_hash)
self._bucket_hashes[route_key] = new_bucket_hash
ratelimit.key = new_bucket_hash + route.major_parameters
self._buckets[ratelimit.key] = ratelimit
# Global rate limit 429 wont have ratelimit headers (also can't tell if it's one-shot)
elif response.headers.get('X-RateLimit-Global'):
self.global_reset_at = self.loop.time() + float(response.headers['Retry-After'])
# Endpoint does not have ratelimit headers; it's one-shot.
else:
ratelimit.no_headers()
# Errors have text involved in them so this is safe to call
data = await json_or_text(response) data = await json_or_text(response)
# Update and use rate limit information if the bucket header is present # The request was successful so just return the text/json
discord_hash = response.headers.get('X-Ratelimit-Bucket')
# I am unsure if X-Ratelimit-Bucket is always available
# However, X-Ratelimit-Remaining has been a consistent cornerstone that worked
has_ratelimit_headers = 'X-Ratelimit-Remaining' in response.headers
if discord_hash is not None:
# If the hash Discord has provided is somehow different from our current hash something changed
if bucket_hash != discord_hash:
if bucket_hash is not None:
# If the previous hash was an actual Discord hash then this means the
# hash has changed sporadically.
# This can be due to two reasons
# 1. It's a sub-ratelimit which is hard to handle
# 2. The rate limit information genuinely changed
# There is no good way to discern these, Discord doesn't provide a way to do so.
# At best, there will be some form of logging to help catch it.
# Alternating sub-ratelimits means that the requests oscillate between
# different underlying rate limits -- this can lead to unexpected 429s
# It is unavoidable.
fmt = 'A route (%s) has changed hashes: %s -> %s.'
_log.debug(fmt, route_key, bucket_hash, discord_hash)
self._bucket_hashes[route_key] = discord_hash
self._buckets[f'{discord_hash}:{route.major_parameters}'] = ratelimit
self._buckets.pop(key, None)
elif route_key not in self._bucket_hashes:
fmt = '%s has found its initial rate limit bucket hash (%s).'
_log.debug(fmt, route_key, discord_hash)
self._bucket_hashes[route_key] = discord_hash
self._buckets[f'{discord_hash}:{route.major_parameters}'] = ratelimit
if has_ratelimit_headers:
if response.status != 429:
ratelimit.update(response, use_clock=self.use_clock)
if ratelimit.remaining == 0:
_log.debug(
'A rate limit bucket (%s) has been exhausted. Pre-emptively rate limiting...',
discord_hash or route_key,
)
# the request was successful so just return the text/json
if 300 > response.status >= 200: if 300 > response.status >= 200:
_log.debug('%s %s has received %s', method, url, data) _log.debug('%s %s has received %s', method, url, data)
return data return data
# we are being rate limited # We are being ratelimited
if response.status == 429: elif response.status == 429:
if not response.headers.get('Via') or isinstance(data, str): retry_after: float = float(response.headers['Retry-After']) # only in headers for cf ban
# Banned by Cloudflare more than likely.
# Hit Cloudflare ban for too many invalid requests (10,000 per 10 minutes)
# An invalid HTTP request is 401, 403, or 429 (excluding "shared" scope).
# This ban is the only one that is IP based, and not just token based.
if not response.headers.get('Via'):
fmt = 'We are Cloudflare banned for %.2f seconds.'
_log.warning(fmt, method, url, retry_after)
raise HTTPException(response, data) raise HTTPException(response, data)
if ratelimit.remaining > 0: elif self.max_ratelimit_timeout and retry_after > self.max_ratelimit_timeout:
# According to night fmt = 'We are being rate limited. %s %s responded with 429. Timeout of %.2f was too long, erroring instead.'
# https://github.com/discord/discord-api-docs/issues/2190#issuecomment-816363129 _log.warning(fmt, method, url, retry_after)
# Remaining > 0 and 429 means that a sub ratelimit was hit.
# It is unclear what should happen in these cases other than just using the retry_after
# value in the body.
_log.debug(
'%s %s received a 429 despite having %s remaining requests. This is a sub-ratelimit.',
method,
url,
ratelimit.remaining,
)
retry_after: float = data['retry_after']
if self.max_ratelimit_timeout and retry_after > self.max_ratelimit_timeout:
_log.warning(
'We are being rate limited. %s %s responded with 429. Timeout of %.2f was too long, erroring instead.',
method,
url,
retry_after,
)
raise RateLimited(retry_after) raise RateLimited(retry_after)
elif ratelimit.remaining > 0:
# According to night, remaining > 0 and 429 means that a sub ratelimit was hit.
# https://github.com/discord/discord-api-docs/issues/2190#issuecomment-816363129
fmt = '%s %s received a 429 despite having %s remaining requests. This is a sub-ratelimit.'
_log.debug(fmt, method, url, ratelimit.remaining)
fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.' fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.'
_log.warning(fmt, method, url, retry_after) _log.warning(fmt, method, url, retry_after)
_log.debug( # We've received a 500, 502, 504, or 524, unconditional retry
'Rate limit is being handled by bucket hash %s with %r major parameters', elif response.status in {500, 502, 504, 524}:
bucket_hash,
route.major_parameters,
)
# check if it's a global rate limit
is_global = data.get('global', False)
if is_global:
_log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after)
self._global_over.clear()
await asyncio.sleep(retry_after)
_log.debug('Done sleeping for the rate limit. Retrying...')
# release the global lock now that the
# global rate limit has passed
if is_global:
self._global_over.set()
_log.debug('Global rate limit is now over.')
continue
# we've received a 500, 502, 504, or 524, unconditional retry
if response.status in {500, 502, 504, 524}:
await asyncio.sleep(1 + tries * 2) await asyncio.sleep(1 + tries * 2)
continue continue
# the usual error cases # The usual error cases
if response.status == 403: elif response.status == 403:
raise Forbidden(response, data) raise Forbidden(response, data)
elif response.status == 404: elif response.status == 404:
raise NotFound(response, data) raise NotFound(response, data)
@ -774,19 +745,17 @@ class HTTPClient:
# This is handling exceptions from the request # This is handling exceptions from the request
except OSError as e: except OSError as e:
# Connection reset by peer if tries == 4 or e.errno not in (54, 10054):
if tries < 4 and e.errno in (54, 10054): raise ValueError("Connection reset by peer")
await asyncio.sleep(1 + tries * 2) retry_after: int = 1 + tries * 2
continue fmt = 'OS error for %s %s. Retrying in %.2f seconds.'
raise _log.warning(fmt, method, url, retry_after)
await asyncio.sleep(retry_after)
if response is not None:
# We've run out of retries, raise. # We've run out of retries, raise
if response.status >= 500: if response is not None:
raise DiscordServerError(response, data) raise HTTPException(response, data)
else:
raise HTTPException(response, data)
raise RuntimeError('Unreachable code in HTTP handling') raise RuntimeError('Unreachable code in HTTP handling')
async def get_from_cdn(self, url: str) -> bytes: async def get_from_cdn(self, url: str) -> bytes:

Loading…
Cancel
Save