@ -52,7 +52,7 @@ import datetime
import aiohttp
from . errors import HTTPException , Forbidden , NotFound , LoginFailure , DiscordServerError , GatewayNotFound
from . errors import HTTPException , RateLimited , Forbidden , NotFound , LoginFailure , DiscordServerError , GatewayNotFound
from . gateway import DiscordClientWebSocketResponse
from . file import File
from . mentions import AllowedMentions
@ -328,13 +328,15 @@ class Ratelimit:
design is to increase throughput of requests being sent concurrently rather than forcing
everything into a single lock queue per route .
"""
def __init__ ( self ) - > None :
def __init__ ( self , max_ratelimit_timeout : Optional [ float ] ) - > None :
self . limit : int = 1
self . remaining : int = self . limit
self . outgoing : int = 0
self . reset_after : float = 0.0
self . expires : Optional [ float ] = None
self . dirty : bool = False
self . _max_ratelimit_timeout : Optional [ float ] = max_ratelimit_timeout
self . _loop : asyncio . AbstractEventLoop = asyncio . get_running_loop ( )
self . _pending_requests : deque [ asyncio . Future [ Any ] ] = deque ( )
# Only a single rate limit object should be sleeping at a time.
@ -381,12 +383,15 @@ class Ratelimit:
future . set_result ( None )
break
def _wake ( self , count : int = 1 ) - > None :
def _wake ( self , count : int = 1 , * , exception : Optional [ RateLimited ] = None ) - > None :
awaken = 0
while self . _pending_requests :
future = self . _pending_requests . popleft ( )
if not future . done ( ) :
future . set_result ( None )
if exception :
future . set_exception ( exception )
else :
future . set_result ( None )
self . _has_just_awaken = True
awaken + = 1
@ -394,10 +399,14 @@ class Ratelimit:
break
async def _refresh ( self ) - > None :
error = self . _max_ratelimit_timeout and self . reset_after > self . _max_ratelimit_timeout
exception = RateLimited ( self . reset_after ) if error else None
async with self . _sleeping :
await asyncio . sleep ( self . reset_after )
if not error :
await asyncio . sleep ( self . reset_after )
self . reset ( )
self . _wake ( self . remaining )
self . _wake ( self . remaining , exception = exception )
def is_expired ( self ) - > bool :
return self . expires is not None and self . _loop . time ( ) > self . expires
@ -406,6 +415,12 @@ class Ratelimit:
if self . is_expired ( ) :
self . reset ( )
if self . _max_ratelimit_timeout is not None and self . expires is not None :
# Check if we can pre-emptively block this request for having too large of a timeout
current_reset_after = self . expires - self . _loop . time ( )
if current_reset_after > self . _max_ratelimit_timeout :
raise RateLimited ( current_reset_after )
while self . remaining < = 0 :
future = self . _loop . create_future ( )
self . _pending_requests . append ( future )
@ -433,7 +448,12 @@ class Ratelimit:
if tokens < = 0 :
await self . _refresh ( )
elif self . _pending_requests :
self . _wake ( tokens )
exception = (
RateLimited ( self . reset_after )
if self . _max_ratelimit_timeout and self . reset_after > self . _max_ratelimit_timeout
else None
)
self . _wake ( tokens , exception = exception )
# For some reason, the Discord voice websocket expects this header to be
@ -453,6 +473,7 @@ class HTTPClient:
proxy_auth : Optional [ aiohttp . BasicAuth ] = None ,
unsync_clock : bool = True ,
http_trace : Optional [ aiohttp . TraceConfig ] = None ,
max_ratelimit_timeout : Optional [ float ] = None ,
) - > None :
self . loop : asyncio . AbstractEventLoop = loop
self . connector : aiohttp . BaseConnector = connector or MISSING
@ -472,6 +493,7 @@ class HTTPClient:
self . proxy_auth : Optional [ aiohttp . BasicAuth ] = proxy_auth
self . http_trace : Optional [ aiohttp . TraceConfig ] = http_trace
self . use_clock : bool = not unsync_clock
self . max_ratelimit_timeout : Optional [ float ] = max ( 30.0 , max_ratelimit_timeout ) if max_ratelimit_timeout else None
user_agent = ' DiscordBot (https://github.com/Rapptz/discord.py {0} ) Python/ {1[0]} . {1[1]} aiohttp/ {2} '
self . user_agent : str = user_agent . format ( __version__ , sys . version_info , aiohttp . __version__ )
@ -520,7 +542,7 @@ class HTTPClient:
try :
ratelimit = mapping [ key ]
except KeyError :
mapping [ key ] = ratelimit = Ratelimit ( )
mapping [ key ] = ratelimit = Ratelimit ( self . max_ratelimit_timeout )
# header creation
headers : Dict [ str , str ] = {
@ -628,10 +650,17 @@ class HTTPClient:
# Banned by Cloudflare more than likely.
raise HTTPException ( response , data )
fmt = ' We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds. '
# sleep a bit
retry_after : float = data [ ' retry_after ' ]
if self . max_ratelimit_timeout and retry_after > self . max_ratelimit_timeout :
_log . warning (
' We are being rate limited. %s %s responded with 429. Timeout of %.2f was too long, erroring instead. ' ,
method ,
url ,
retry_after ,
)
raise RateLimited ( retry_after )
fmt = ' We are being rate limited. %s %s responded with 429. Retrying in %.2f seconds. '
_log . warning ( fmt , method , url , retry_after , stack_info = True )
_log . debug (