Browse Source

Rewrite rate limit handling to use X-Ratelimit-Bucket and a semaphore

This should increase request throughput by allowing more requests to be
made at once, while also following the new standard practice of using
the rate limit bucket header.

This is the result of many months of work between a few people, and it
has been tested extensively. From the testing it seems to work fine,
but I'm not sure if it's the best way to do it.

This changeset does not yet take sub rate limits into consideration,
but the foundation is there via Route.metadata. In the future, this
metadata will be filled in with the known sub rate limit
implementation so that sub rate limits can have separate keys in the
rate limit mapping.

Co-authored-by: Josh <josh.ja.butt@gmail.com>
pull/8258/head
Rapptz 3 years ago
parent
commit
c17eb31328
  1. 244
      discord/http.py

244
discord/http.py

@ -46,7 +46,9 @@ from typing import (
Union,
)
from urllib.parse import quote as _uriquote
from collections import deque
import weakref
import datetime
import aiohttp
@ -283,9 +285,12 @@ def _set_api_version(value: int):
class Route:
BASE: ClassVar[str] = 'https://discord.com/api/v10'
def __init__(self, method: str, path: str, **parameters: Any) -> None:
def __init__(self, method: str, path: str, *, metadata: Optional[str] = None, **parameters: Any) -> None:
self.path: str = path
self.method: str = method
# Metadata is a special string used to differentiate between known sub rate limits
# Since these can't be handled generically, this is the next best way to do so.
self.metadata: Optional[str] = metadata
url = self.BASE + self.path
if parameters:
url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()})
@ -298,30 +303,137 @@ class Route:
self.webhook_token: Optional[str] = parameters.get('webhook_token')
@property
def bucket(self) -> str:
# the bucket is just method + path w/ major parameters
return f'{self.channel_id}:{self.guild_id}:{self.path}'
def key(self) -> str:
    """Return the bucket key that represents this route in the rate limit mappings.

    The key is the method and path separated by a space. When ``metadata`` is
    set to a non-empty value it is appended after a colon.
    """
    if not self.metadata:
        return f'{self.method} {self.path}'
    return f'{self.method} {self.path}:{self.metadata}'
@property
def major_parameters(self) -> str:
"""Returns the major parameters formatted a string.
class MaybeUnlock:
def __init__(self, lock: asyncio.Lock) -> None:
self.lock: asyncio.Lock = lock
self._unlock: bool = True
This needs to be appended to a bucket hash to constitute a full rate limit key.
"""
return '+'.join(
str(k) for k in (self.channel_id, self.guild_id, self.webhook_id, self.webhook_token) if k is not None
)
def __enter__(self) -> Self:
return self
def defer(self) -> None:
self._unlock = False
class Ratelimit:
"""Represents a Discord rate limit.
This is similar to a semaphore except tailored to Discord's rate limits. This is aware of
the expiry of a token window, along with the number of tokens available. The goal of this
design is to increase throughput of requests being sent concurrently rather than forcing
everything into a single lock queue per route.
"""
def __init__(self) -> None:
    """Set up an unprimed rate limit that starts out holding a single token.

    The running event loop is captured with ``asyncio.get_running_loop()``, so
    instances must be created while a loop is running.
    """
    # Size of the token window; overwritten from X-Ratelimit-Limit by update().
    self.limit: int = 1
    # Tokens still available to acquire() in the current window.
    self.remaining: int = self.limit
    # Requests that hold a token and have not yet gone through __aexit__.
    self.outgoing: int = 0
    # Seconds until the window resets, as last computed by update().
    self.reset_after: float = 0.0
    # Loop time at which the window ends; None while no window is known.
    self.expires: Optional[float] = None
    # True once update() has processed a response since the last reset().
    self.dirty: bool = False
    # Set to True by _wake() whenever it resolves a waiter. Declared here so the
    # attribute always exists instead of first appearing on the first wake-up.
    self._has_just_awaken: bool = False
    self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
    # FIFO queue of futures, one per request waiting inside acquire().
    self._pending_requests: deque[asyncio.Future[Any]] = deque()
    # Only a single rate limit object should be sleeping at a time.
    # The object that is sleeping is ultimately responsible for freeing the semaphore
    # for the requests currently pending.
    self._sleeping: asyncio.Lock = asyncio.Lock()
def __repr__(self) -> str:
    """Return a debug representation showing the limit, remaining tokens and queue length.

    The leading name is taken from the instance's class so that it matches the
    actual class name rather than a hard-coded label.
    """
    name = type(self).__name__
    return f'<{name} limit={self.limit} remaining={self.remaining} pending_requests={len(self._pending_requests)}>'
def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self._unlock:
self.lock.release()
def reset(self):
    """Begin a fresh window: clear the expiry data and restore the available tokens."""
    self.dirty = False
    self.reset_after = 0.0
    self.expires = None
    # Requests that are still outgoing keep holding their tokens across the reset.
    self.remaining = self.limit - self.outgoing
def update(self, response: aiohttp.ClientResponse, *, use_clock: bool = False) -> None:
    """Refresh the limit, remaining tokens and expiry from a response's rate limit headers.

    Parameters
    ----------
    response:
        The response whose ``X-Ratelimit-*`` headers are read.
    use_clock:
        When true, or when ``X-Ratelimit-Reset-After`` is missing or empty, the
        delay is derived from ``X-Ratelimit-Reset`` and the current UTC time
        instead of from ``X-Ratelimit-Reset-After``.

    Raises
    ------
    KeyError
        If the clock-based path is taken and ``X-Ratelimit-Reset`` is absent.
    """
    headers = response.headers
    self.limit = int(headers.get('X-Ratelimit-Limit', 1))

    reported = int(headers.get('X-Ratelimit-Remaining', 0))
    if self.dirty:
        # Already updated in this window: never report more tokens than the
        # requests still in flight leave available.
        self.remaining = min(reported, self.limit - self.outgoing)
    else:
        self.remaining = reported
        self.dirty = True

    relative = headers.get('X-Ratelimit-Reset-After')
    if relative and not use_clock:
        self.reset_after = float(relative)
    else:
        tz = datetime.timezone.utc
        current = datetime.datetime.now(tz)
        resets_at = datetime.datetime.fromtimestamp(float(headers['X-Ratelimit-Reset']), tz)
        self.reset_after = (resets_at - current).total_seconds()

    # Expiry is tracked on the event loop's clock, not on wall-clock time.
    self.expires = self._loop.time() + self.reset_after
def _wake_next(self) -> None:
    """Resolve the first queued future that is not already done, discarding done ones before it."""
    queue = self._pending_requests
    while queue:
        waiter = queue.popleft()
        if waiter.done():
            # Already resolved or cancelled; drop it and look at the next one.
            continue
        waiter.set_result(None)
        return
def _wake(self, count: int = 1) -> None:
# Wake queued requests by resolving their pending futures in FIFO order.
# `awaken` counts the futures actually resolved; the loop ends once
# `awaken >= count` or the queue is empty.
awaken = 0
while self._pending_requests:
future = self._pending_requests.popleft()
# Futures that are already done are popped and skipped without being counted.
if not future.done():
future.set_result(None)
# NOTE(review): _has_just_awaken is only written here in the visible code; it is not
# initialised in __init__ and nothing shown reads it -- confirm it is still needed.
self._has_just_awaken = True
awaken += 1
# NOTE(review): indentation is not visible in this view, so it is unclear whether this
# check is nested under `if not future.done()` -- confirm against the real file.
if awaken >= count:
break
async def _refresh(self) -> None:
# Sleep for `reset_after` seconds, then reset the token count and wake waiters.
# The _sleeping lock is held while sleeping; __aexit__ checks _sleeping.locked()
# before calling this, so a second sleep is not started while one is in progress.
async with self._sleeping:
await asyncio.sleep(self.reset_after)
# NOTE(review): indentation is lost in this view; presumably reset() and _wake() run
# after the lock has been released -- confirm against the real file.
self.reset()
# reset() has just recomputed `remaining`; it is passed on as the wake count.
self._wake(self.remaining)
def is_expired(self) -> bool:
    """Return True when an expiry time is known and the loop clock has moved past it."""
    if self.expires is None:
        # No window information yet, so there is nothing to expire.
        return False
    return self._loop.time() > self.expires
async def acquire(self) -> None:
# Take one token, waiting in the pending queue while none are left.
# A window that has already expired is reset before tokens are checked.
if self.is_expired():
self.reset()
# Re-check in a loop: being woken does not by itself guarantee a token is available.
while self.remaining <= 0:
future = self._loop.create_future()
self._pending_requests.append(future)
try:
await future
except:
# A bare except also catches asyncio.CancelledError; the exception is re-raised below.
future.cancel()
# If the future had already been resolved (so cancel() had no effect) and tokens
# remain, hand the wake-up to the next waiter instead of losing it.
if self.remaining > 0 and not future.cancelled():
self._wake_next()
raise
# Claim the token and count the request as outgoing until __aexit__ runs.
self.remaining -= 1
self.outgoing += 1
async def __aenter__(self) -> Self:
    """Wait for a token through ``acquire()`` and return this object to the ``async with`` body."""
    await self.acquire()
    return self
async def __aexit__(self, type: Type[BE], value: BE, traceback: TracebackType) -> None:
# Release the outgoing slot taken in acquire(), then either sleep out the
# window or wake queued requests. The exception arguments are not inspected.
self.outgoing -= 1
# Tokens left once the requests that are still outgoing are accounted for.
tokens = self.remaining - self.outgoing
# Check whether the rate limit needs to be pre-emptively slept on
# Note that this is a Lock to prevent multiple rate limit objects from sleeping at once
if not self._sleeping.locked():
if tokens <= 0:
# Nothing left to hand out: wait for the window to reset, which also wakes waiters.
await self._refresh()
elif self._pending_requests:
# Spare tokens and queued requests: pass the spare count to _wake() right away.
self._wake(tokens)
# For some reason, the Discord voice websocket expects this header to be
@ -345,7 +457,15 @@ class HTTPClient:
self.loop: asyncio.AbstractEventLoop = loop
self.connector: aiohttp.BaseConnector = connector or MISSING
self.__session: aiohttp.ClientSession = MISSING # filled in static_login
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
# Route key -> Bucket hash
self._bucket_hashes: Dict[str, str] = {}
# Bucket Hash + Major Parameters -> Rate limit
self._buckets: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary()
# Route key + Major Parameters -> Rate limit
# Used for temporary one shot requests that don't have a bucket hash
# While I'd love to use a single mapping for these, doing this would cause the rate limit objects
# to inexplicably be evicted from the dictionary before we're done with them
self._oneshots: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary()
self._global_over: asyncio.Event = MISSING
self.token: Optional[str] = None
self.proxy: Optional[str] = proxy
@ -383,15 +503,24 @@ class HTTPClient:
form: Optional[Iterable[Dict[str, Any]]] = None,
**kwargs: Any,
) -> Any:
bucket = route.bucket
method = route.method
url = route.url
route_key = route.key
bucket_hash = None
try:
bucket_hash = self._bucket_hashes[route_key]
except KeyError:
key = route_key + route.major_parameters
mapping = self._oneshots
else:
key = bucket_hash + route.major_parameters
mapping = self._buckets
lock = self._locks.get(bucket)
if lock is None:
lock = asyncio.Lock()
if bucket is not None:
self._locks[bucket] = lock
try:
ratelimit = mapping[key]
except KeyError:
mapping[key] = ratelimit = Ratelimit()
# header creation
headers: Dict[str, str] = {
@ -427,8 +556,7 @@ class HTTPClient:
response: Optional[aiohttp.ClientResponse] = None
data: Optional[Union[Dict[str, Any], str]] = None
await lock.acquire()
with MaybeUnlock(lock) as maybe_lock:
async with ratelimit:
for tries in range(5):
if files:
for f in files:
@ -448,14 +576,46 @@ class HTTPClient:
# even errors have text involved in them so this is safe to call
data = await json_or_text(response)
# check if we have rate limit header information
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status != 429:
# we've depleted our current bucket
delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock)
_log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
maybe_lock.defer()
self.loop.call_later(delta, lock.release)
# Update and use rate limit information if the bucket header is present
discord_hash = response.headers.get('X-Ratelimit-Bucket')
# I am unsure if X-Ratelimit-Bucket is always available
# However, X-Ratelimit-Remaining has been a consistent cornerstone that worked
has_ratelimit_headers = 'X-Ratelimit-Remaining' in response.headers
if discord_hash is not None:
# If the hash Discord has provided is somehow different from our current hash something changed
if bucket_hash != discord_hash:
if bucket_hash is not None:
# If the previous hash was an actual Discord hash then this means the
# hash has changed sporadically.
# This can be due to two reasons
# 1. It's a sub-ratelimit which is hard to handle
# 2. The rate limit information genuinely changed
# There is no good way to discern these, Discord doesn't provide a way to do so.
# At best, there will be some form of logging to help catch it.
# Alternating sub-ratelimits means that the requests oscillate between
# different underlying rate limits -- this can lead to unexpected 429s
# It is unavoidable.
fmt = 'A route (%s) has changed hashes: %s -> %s.'
_log.debug(fmt, route_key, bucket_hash, discord_hash)
self._bucket_hashes[route_key] = discord_hash
recalculated_key = discord_hash + route.major_parameters
self._buckets[recalculated_key] = ratelimit
self._buckets.pop(key, None)
elif route_key not in self._bucket_hashes:
fmt = '%s has found its initial rate limit bucket hash (%s).'
_log.debug(fmt, route_key, discord_hash)
self._bucket_hashes[route_key] = discord_hash
self._buckets[discord_hash + route.major_parameters] = ratelimit
if has_ratelimit_headers:
if response.status != 429:
ratelimit.update(response, use_clock=self.use_clock)
if ratelimit.remaining == 0:
_log.debug(
'A rate limit bucket (%s) has been exhausted. Pre-emptively rate limiting...',
discord_hash or route_key,
)
# the request was successful so just return the text/json
if 300 > response.status >= 200:
@ -468,11 +628,17 @@ class HTTPClient:
# Banned by Cloudflare more than likely.
raise HTTPException(response, data)
fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"'
fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.'
# sleep a bit
retry_after: float = data['retry_after']
_log.warning(fmt, retry_after, bucket, stack_info=True)
_log.warning(fmt, method, url, retry_after, stack_info=True)
_log.debug(
'Rate limit is being handled by bucket hash %s with %r major parameters',
bucket_hash,
route.major_parameters,
)
# check if it's a global rate limit
is_global = data.get('global', False)

Loading…
Cancel
Save