@ -28,10 +28,13 @@ import asyncio
import itertools
import logging
import aiohttp
from . state import AutoShardedConnectionState
from . client import Client
from . backoff import ExponentialBackoff
from . gateway import *
from . errors import ClientException , InvalidArgument , ConnectionClosed
from . errors import ClientException , InvalidArgument , HTTPException , GatewayNotFound , ConnectionClosed
from . import utils
from . enums import Status
@ -39,8 +42,9 @@ log = logging.getLogger(__name__)
class EventType :
close = 0
resume = 1
identify = 2
reconnect = 1
resume = 2
identify = 3
class EventItem :
__slots__ = ( ' type ' , ' shard ' , ' error ' )
@ -70,7 +74,18 @@ class Shard:
self . _dispatch = client . dispatch
self . _queue = client . _queue
self . loop = self . _client . loop
self . _disconnect = False
self . _reconnect = client . _reconnect
self . _backoff = ExponentialBackoff ( )
self . _task = None
self . _handled_exceptions = (
OSError ,
HTTPException ,
GatewayNotFound ,
ConnectionClosed ,
aiohttp . ClientError ,
asyncio . TimeoutError ,
)
@property
def id ( self ) :
@ -79,6 +94,33 @@ class Shard:
def launch ( self ) :
self . _task = self . loop . create_task ( self . worker ( ) )
def _cancel_task ( self ) :
if self . _task is not None and not self . _task . done ( ) :
self . _task . cancel ( )
async def close ( self ) :
self . _cancel_task ( )
await self . ws . close ( code = 1000 )
async def _handle_disconnect ( self , e ) :
self . _dispatch ( ' disconnect ' )
if not self . _reconnect :
self . _queue . put_nowait ( EventItem ( EventType . close , self , e ) )
return
if self . _client . is_closed ( ) :
return
if isinstance ( e , ConnectionClosed ) :
if e . code != 1000 :
self . _queue . put_nowait ( EventItem ( EventType . close , self , e ) )
return
retry = self . _backoff . delay ( )
log . error ( ' Attempting a reconnect for shard ID %s in %.2f s ' , self . id , retry , exc_info = e )
await asyncio . sleep ( retry )
self . _queue . put_nowait ( EventItem ( EventType . reconnect , self , e ) )
async def worker ( self ) :
while not self . _client . is_closed ( ) :
try :
@ -87,14 +129,12 @@ class Shard:
etype = EventType . resume if e . resume else EventType . identify
self . _queue . put_nowait ( EventItem ( etype , self , e ) )
break
except ConnectionClosed as e :
self . _queue . put_nowait ( EventItem ( EventType . close , self , e ) )
except self . _handled_exceptions as e :
await self . _handle_disconnect ( e )
break
async def reconnect ( self , exc ) :
if self . _task is not None and not self . _task . done ( ) :
self . _task . cancel ( )
async def reidentify ( self , exc ) :
self . _cancel_task ( )
log . info ( ' Got a request to %s the websocket at Shard ID %s . ' , exc . op , self . id )
coro = DiscordWebSocket . from_client ( self . _client , resume = exc . resume , shard_id = self . id ,
session = self . ws . session_id , sequence = self . ws . sequence )
@ -102,6 +142,16 @@ class Shard:
self . ws = await asyncio . wait_for ( coro , timeout = 180.0 )
self . launch ( )
async def reconnect ( self ) :
self . _cancel_task ( )
try :
coro = DiscordWebSocket . from_client ( self . _client , shard_id = self . id )
self . ws = await asyncio . wait_for ( coro , timeout = 180.0 )
except self . _handled_exceptions as e :
await self . _handle_disconnect ( e )
else :
self . launch ( )
class AutoShardedClient ( Client ) :
""" A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
@ -235,15 +285,21 @@ class AutoShardedClient(Client):
self . _connection . shards_launched . set ( )
async def _connect ( self ) :
async def connect ( self , * , reconnect = True ) :
self . _reconnect = reconnect
await self . launch_shards ( )
while True :
while not self . is_closed ( ) :
item = await self . _queue . get ( )
if item . type == EventType . close :
raise item . error
await self . close ( )
if isinstance ( item . error , ConnectionClosed ) and item . error . code != 1000 :
raise item . error
return
elif item . type in ( EventType . identify , EventType . resume ) :
await item . shard . reconnect ( item . error )
await item . shard . reidentify ( item . error )
elif item . type == EventType . reconnect :
await item . shard . reconnect ( )
async def close ( self ) :
""" |coro|
@ -261,7 +317,7 @@ class AutoShardedClient(Client):
except Exception :
pass
to_close = [ asyncio . ensure_future ( shard . ws . close ( code = 1000 ) , loop = self . loop ) for shard in self . shards . values ( ) ]
to_close = [ asyncio . ensure_future ( shard . close ( ) , loop = self . loop ) for shard in self . shards . values ( ) ]
if to_close :
await asyncio . wait ( to_close )