@ -28,7 +28,7 @@ from .state import AutoShardedConnectionState
from . client import Client
from . client import Client
from . gateway import *
from . gateway import *
from . errors import ClientException , InvalidArgument
from . errors import ClientException , InvalidArgument
from . import compat
from . import compat , utils
from . enums import Status
from . enums import Status
import asyncio
import asyncio
@ -45,11 +45,32 @@ class Shard:
self . loop = self . _client . loop
self . loop = self . _client . loop
self . _current = compat . create_future ( self . loop )
self . _current = compat . create_future ( self . loop )
self . _current . set_result ( None ) # we just need an already done future
self . _current . set_result ( None ) # we just need an already done future
self . _pending = asyncio . Event ( loop = self . loop )
self . _pending_task = None
@property
@property
def id ( self ) :
def id ( self ) :
return self . ws . shard_id
return self . ws . shard_id
def is_pending ( self ) :
return not self . _pending . is_set ( )
def complete_pending_reads ( self ) :
self . _pending . set ( )
def _pending_reads ( self ) :
try :
while self . is_pending ( ) :
yield from self . poll ( )
except asyncio . CancelledError :
pass
def launch_pending_reads ( self ) :
self . _pending_task = compat . create_task ( self . _pending_reads ( ) , loop = self . loop )
def wait ( self ) :
return self . _pending_task
@asyncio . coroutine
@asyncio . coroutine
def poll ( self ) :
def poll ( self ) :
try :
try :
@ -127,7 +148,6 @@ class AutoShardedClient(Client):
return self . shards [ i ] . ws
return self . shards [ i ] . ws
self . _connection . _get_websocket = _get_websocket
self . _connection . _get_websocket = _get_websocket
self . _still_sharding = True
@asyncio . coroutine
@asyncio . coroutine
def _chunker ( self , guild , * , shard_id = None ) :
def _chunker ( self , guild , * , shard_id = None ) :
@ -199,14 +219,6 @@ class AutoShardedClient(Client):
sub_guilds = list ( sub_guilds )
sub_guilds = list ( sub_guilds )
yield from self . _connection . request_offline_members ( sub_guilds , shard_id = shard_id )
yield from self . _connection . request_offline_members ( sub_guilds , shard_id = shard_id )
@asyncio . coroutine
def pending_reads ( self , shard ) :
try :
while self . _still_sharding :
yield from shard . poll ( )
except asyncio . CancelledError :
pass
@asyncio . coroutine
@asyncio . coroutine
def launch_shard ( self , gateway , shard_id ) :
def launch_shard ( self , gateway , shard_id ) :
try :
try :
@ -235,7 +247,7 @@ class AutoShardedClient(Client):
# keep reading the shard while others connect
# keep reading the shard while others connect
self . shards [ shard_id ] = ret = Shard ( ws , self )
self . shards [ shard_id ] = ret = Shard ( ws , self )
compat . create_task ( self . pending_reads ( ret ) , loop = self . loop )
ret . launch_pending_reads ( )
yield from asyncio . sleep ( 5.0 , loop = self . loop )
yield from asyncio . sleep ( 5.0 , loop = self . loop )
@asyncio . coroutine
@asyncio . coroutine
@ -252,7 +264,13 @@ class AutoShardedClient(Client):
for shard_id in shard_ids :
for shard_id in shard_ids :
yield from self . launch_shard ( gateway , shard_id )
yield from self . launch_shard ( gateway , shard_id )
self . _still_sharding = False
shards_to_wait_for = [ ]
for shard in self . shards . values ( ) :
shard . complete_pending_reads ( )
shards_to_wait_for . append ( shard . wait ( ) )
# wait for all pending tasks to finish
yield from utils . sane_wait_for ( shards_to_wait_for , timeout = 300.0 , loop = self . loop )
@asyncio . coroutine
@asyncio . coroutine
def _connect ( self ) :
def _connect ( self ) :