@@ -46,7 +46,9 @@ from typing import (
    Union,
)
from urllib.parse import quote as _uriquote
from collections import deque
import weakref
import datetime
import aiohttp
@@ -283,9 +285,12 @@ def _set_api_version(value: int):
class Route:
    BASE: ClassVar[str] = 'https://discord.com/api/v10'

    def __init__(self, method: str, path: str, **parameters: Any) -> None:
    def __init__(self, method: str, path: str, *, metadata: Optional[str] = None, **parameters: Any) -> None:
        self.path: str = path
        self.method: str = method
        # Metadata is a special string used to differentiate between known sub rate limits
        # Since these can't be handled generically, this is the next best way to do so.
        self.metadata: Optional[str] = metadata
        url = self.BASE + self.path
        if parameters:
            url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()})
@@ -298,30 +303,137 @@ class Route:
        self.webhook_token: Optional[str] = parameters.get('webhook_token')

    @property
    def bucket(self) -> str:
        # the bucket is just method + path w/ major parameters
        return f'{self.channel_id}:{self.guild_id}:{self.path}'
    def key(self) -> str:
        """The bucket key is used to represent the route in various mappings."""
        if self.metadata:
            return f'{self.method} {self.path}:{self.metadata}'
        return f'{self.method} {self.path}'
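        # A hypothetical example (not from the change itself): Route('GET', '/channels/{channel_id}', channel_id=123)
        # produces the key 'GET /channels/{channel_id}'; with metadata='reactions' (made-up value) it would be
        # 'GET /channels/{channel_id}:reactions'. The path keeps its placeholders, so every request to the same
        # endpoint shares one key regardless of the actual IDs involved.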
    @property
    def major_parameters(self) -> str:
        """Returns the major parameters formatted as a string.

class MaybeUnlock:
    def __init__(self, lock: asyncio.Lock) -> None:
        self.lock: asyncio.Lock = lock
        self._unlock: bool = True

        This needs to be appended to a bucket hash to constitute as a full rate limit key.
        """
        return '+'.join(
            str(k) for k in (self.channel_id, self.guild_id, self.webhook_id, self.webhook_token) if k is not None
        )
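        # A hypothetical example: with channel_id=123456 and guild_id=789 this returns '123456+789'.
        # HTTPClient.request appends it to either the route key or a learned bucket hash (e.g. a
        # made-up hash 'abc123' gives the full rate limit key 'abc123123456+789').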
    def __enter__(self) -> Self:
        return self

    def defer(self) -> None:
        self._unlock = False
class Ratelimit:
    """Represents a Discord rate limit.

    This is similar to a semaphore except tailored to Discord's rate limits. This is aware of
    the expiry of a token window, along with the number of tokens available. The goal of this
    design is to increase throughput of requests being sent concurrently rather than forcing
    everything into a single lock queue per route.
    """
    def __init__(self) -> None:
        self.limit: int = 1
        self.remaining: int = self.limit
        self.outgoing: int = 0
        self.reset_after: float = 0.0
        self.expires: Optional[float] = None
        self.dirty: bool = False
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._pending_requests: deque[asyncio.Future[Any]] = deque()
        # Only a single rate limit object should be sleeping at a time.
        # The object that is sleeping is ultimately responsible for freeing the semaphore
        # for the requests currently pending.
        self._sleeping: asyncio.Lock = asyncio.Lock()

    def __repr__(self) -> str:
        return (
            f'<RateLimitBucket limit={self.limit} remaining={self.remaining} pending_requests={len(self._pending_requests)}>'
        )
    def __exit__(
        self,
        exc_type: Optional[Type[BE]],
        exc: Optional[BE],
        traceback: Optional[TracebackType],
    ) -> None:
        if self._unlock:
            self.lock.release()

    def reset(self):
        self.remaining = self.limit - self.outgoing
        self.expires = None
        self.reset_after = 0.0
        self.dirty = False
    def update(self, response: aiohttp.ClientResponse, *, use_clock: bool = False) -> None:
        headers = response.headers
        self.limit = int(headers.get('X-Ratelimit-Limit', 1))

        if self.dirty:
            self.remaining = min(int(headers.get('X-Ratelimit-Remaining', 0)), self.limit - self.outgoing)
        else:
            self.remaining = int(headers.get('X-Ratelimit-Remaining', 0))
            self.dirty = True

        reset_after = headers.get('X-Ratelimit-Reset-After')
        if use_clock or not reset_after:
            utc = datetime.timezone.utc
            now = datetime.datetime.now(utc)
            reset = datetime.datetime.fromtimestamp(float(headers['X-Ratelimit-Reset']), utc)
            self.reset_after = (reset - now).total_seconds()
        else:
            self.reset_after = float(reset_after)

        self.expires = self._loop.time() + self.reset_after
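        # A hypothetical example: headers of X-Ratelimit-Limit: 5, X-Ratelimit-Remaining: 3 and
        # X-Ratelimit-Reset-After: 2.5 leave this object with limit=5, remaining=3 (capped at
        # limit - outgoing once dirty), reset_after=2.5, and expires set 2.5 seconds ahead on the
        # event loop clock.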
    def _wake_next(self) -> None:
        while self._pending_requests:
            future = self._pending_requests.popleft()
            if not future.done():
                future.set_result(None)
                break

    def _wake(self, count: int = 1) -> None:
        awaken = 0
        while self._pending_requests:
            future = self._pending_requests.popleft()
            if not future.done():
                future.set_result(None)
                self._has_just_awaken = True
                awaken += 1

            if awaken >= count:
                break
    async def _refresh(self) -> None:
        async with self._sleeping:
            await asyncio.sleep(self.reset_after)

        self.reset()
        self._wake(self.remaining)

    def is_expired(self) -> bool:
        return self.expires is not None and self._loop.time() > self.expires
    async def acquire(self) -> None:
        if self.is_expired():
            self.reset()

        while self.remaining <= 0:
            future = self._loop.create_future()
            self._pending_requests.append(future)
            try:
                await future
            except:
                future.cancel()
                if self.remaining > 0 and not future.cancelled():
                    self._wake_next()
                raise

        self.remaining -= 1
        self.outgoing += 1
    async def __aenter__(self) -> Self:
        await self.acquire()
        return self

    async def __aexit__(self, type: Type[BE], value: BE, traceback: TracebackType) -> None:
        self.outgoing -= 1
        tokens = self.remaining - self.outgoing
        # Check whether the rate limit needs to be pre-emptively slept on
        # Note that this is a Lock to prevent multiple rate limit objects from sleeping at once
        if not self._sleeping.locked():
            if tokens <= 0:
                await self._refresh()
            elif self._pending_requests:
                self._wake(tokens)
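
# A rough usage sketch (illustrative only, not part of the change): HTTPClient.request drives one of
# these objects per bucket roughly as below; 'session', 'method' and 'url' are placeholders. acquire()
# spends a token or parks the caller on a future when none remain; __aexit__ then either pre-emptively
# sleeps out the window via _refresh() or wakes as many parked requests as there are spare tokens.
#
#   async def do_request(ratelimit: Ratelimit) -> None:        # Ratelimit() must be created on a running loop
#       async with ratelimit:                                   # acquire(): waits if remaining <= 0
#           async with session.request(method, url) as response:
#               ratelimit.update(response)                      # ingest the X-Ratelimit-* headers
#       # __aexit__ then sleeps out the window or wakes pending requests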
# For some reason, the Discord voice websocket expects this header to be
@@ -345,7 +457,15 @@ class HTTPClient:
        self.loop: asyncio.AbstractEventLoop = loop
        self.connector: aiohttp.BaseConnector = connector or MISSING
        self.__session: aiohttp.ClientSession = MISSING  # filled in static_login
        self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        # Route key -> Bucket hash
        self._bucket_hashes: Dict[str, str] = {}
        # Bucket Hash + Major Parameters -> Rate limit
        self._buckets: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary()
        # Route key + Major Parameters -> Rate limit
        # Used for temporary one shot requests that don't have a bucket hash
        # While I'd love to use a single mapping for these, doing this would cause the rate limit objects
        # to inexplicably be evicted from the dictionary before we're done with it
        self._oneshots: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary()
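        # A hypothetical walk-through: the first request to 'GET /channels/{channel_id}' with
        # channel_id=123 lands in _oneshots under 'GET /channels/{channel_id}123'. Once Discord
        # returns an X-Ratelimit-Bucket header (say a made-up hash 'abc123'), _bucket_hashes maps
        # the route key to that hash and later requests reuse the same Ratelimit from _buckets
        # under 'abc123123'.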
        self._global_over: asyncio.Event = MISSING
        self.token: Optional[str] = None
        self.proxy: Optional[str] = proxy
@@ -383,15 +503,24 @@ class HTTPClient:
        form: Optional[Iterable[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> Any:
        bucket = route.bucket
        method = route.method
        url = route.url
        route_key = route.key

        bucket_hash = None
        try:
            bucket_hash = self._bucket_hashes[route_key]
        except KeyError:
            key = route_key + route.major_parameters
            mapping = self._oneshots
        else:
            key = bucket_hash + route.major_parameters
            mapping = self._buckets

        lock = self._locks.get(bucket)
        if lock is None:
            lock = asyncio.Lock()
            if bucket is not None:
                self._locks[bucket] = lock

        try:
            ratelimit = mapping[key]
        except KeyError:
            mapping[key] = ratelimit = Ratelimit()
        # header creation
        headers: Dict[str, str] = {
@@ -427,8 +556,7 @@ class HTTPClient:
        response: Optional[aiohttp.ClientResponse] = None
        data: Optional[Union[Dict[str, Any], str]] = None
        await lock.acquire()
        with MaybeUnlock(lock) as maybe_lock:
        async with ratelimit:
            for tries in range(5):
                if files:
                    for f in files:
@@ -448,14 +576,46 @@ class HTTPClient:
                        # even errors have text involved in them so this is safe to call
                        data = await json_or_text(response)

                        # check if we have rate limit header information
                        remaining = response.headers.get('X-Ratelimit-Remaining')
                        if remaining == '0' and response.status != 429:
                            # we've depleted our current bucket
                            delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock)
                            _log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta)
                            maybe_lock.defer()
                            self.loop.call_later(delta, lock.release)
                        # Update and use rate limit information if the bucket header is present
                        discord_hash = response.headers.get('X-Ratelimit-Bucket')
                        # I am unsure if X-Ratelimit-Bucket is always available
                        # However, X-Ratelimit-Remaining has been a consistent cornerstone that worked
                        has_ratelimit_headers = 'X-Ratelimit-Remaining' in response.headers
                        if discord_hash is not None:
                            # If the hash Discord has provided is somehow different from our current hash something changed
                            if bucket_hash != discord_hash:
                                if bucket_hash is not None:
                                    # If the previous hash was an actual Discord hash then this means the
                                    # hash has changed sporadically.
                                    # This can be due to two reasons
                                    # 1. It's a sub-ratelimit which is hard to handle
                                    # 2. The rate limit information genuinely changed
                                    # There is no good way to discern these, Discord doesn't provide a way to do so.
                                    # At best, there will be some form of logging to help catch it.
                                    # Alternating sub-ratelimits means that the requests oscillate between
                                    # different underlying rate limits -- this can lead to unexpected 429s
                                    # It is unavoidable.
                                    fmt = 'A route (%s) has changed hashes: %s -> %s.'
                                    _log.debug(fmt, route_key, bucket_hash, discord_hash)

                                    self._bucket_hashes[route_key] = discord_hash
                                    recalculated_key = discord_hash + route.major_parameters
                                    self._buckets[recalculated_key] = ratelimit
                                    self._buckets.pop(key, None)
                                elif route_key not in self._bucket_hashes:
                                    fmt = '%s has found its initial rate limit bucket hash (%s).'
                                    _log.debug(fmt, route_key, discord_hash)
                                    self._bucket_hashes[route_key] = discord_hash
                                    self._buckets[discord_hash + route.major_parameters] = ratelimit

                        if has_ratelimit_headers:
                            if response.status != 429:
                                ratelimit.update(response, use_clock=self.use_clock)
                                if ratelimit.remaining == 0:
                                    _log.debug(
                                        'A rate limit bucket (%s) has been exhausted. Pre-emptively rate limiting...',
                                        discord_hash or route_key,
                                    )
                        # the request was successful so just return the text/json
                        if 300 > response.status >= 200:
@@ -468,11 +628,17 @@ class HTTPClient:
                                # Banned by Cloudflare more than likely.
                                raise HTTPException(response, data)

                            fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"'
                            fmt = 'We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds.'

                            # sleep a bit
                            retry_after: float = data['retry_after']
                            _log.warning(fmt, retry_after, bucket, stack_info=True)
                            _log.warning(fmt, method, url, retry_after, stack_info=True)

                            _log.debug(
                                'Rate limit is being handled by bucket hash %s with %r major parameters',
                                bucket_hash,
                                route.major_parameters,
                            )

                            # check if it's a global rate limit
                            is_global = data.get('global', False)