diff --git a/discord/gateway.py b/discord/gateway.py index fcba2dfcc..8180f4ec3 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -214,35 +214,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): else: return ws - @classmethod - @asyncio.coroutine - def from_sharded_client(cls, client): - if client.shard_count is None: - client.shard_count, gateway = yield from client.http.get_bot_gateway() - else: - gateway = yield from client.http.get_gateway() - - ret = [] - client.connection.shard_count = client.shard_count - - for shard_id in range(client.shard_count): - ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) - ws.token = client.http.token - ws._connection = client.connection - ws._dispatch = client.dispatch - ws.gateway = gateway - ws.shard_id = shard_id - ws.shard_count = client.shard_count - - # OP HELLO - yield from ws.poll_event() - yield from ws.identify() - ret.append(ws) - log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) - yield from asyncio.sleep(5.0, loop=client.loop) - - return ret - def wait_for(self, event, predicate, result=None): """Waits for a DISPATCH'd event that meets the predicate. diff --git a/discord/shard.py b/discord/shard.py index 2be0ea128..df0973b34 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -32,6 +32,7 @@ from . import compat import asyncio import logging +import websockets log = logging.getLogger(__name__) @@ -93,8 +94,10 @@ class AutoShardedClient(Client): syncer=self._syncer, http=self.http, loop=self.loop, **kwargs) # instead of a single websocket, we have multiple - # the index is the shard_id - self.shards = [] + # the key is the shard_id + self.shards = {} + + self._still_sharding = True @asyncio.coroutine def request_offline_members(self, guild, *, shard_id=None): @@ -135,6 +138,56 @@ class AutoShardedClient(Client): ws = self.shards[shard_id].ws yield from ws.send_as_json(payload) + @asyncio.coroutine + def pending_reads(self, shard): + try: + while self._still_sharding: + yield from shard.poll() + except asyncio.CancelledError: + pass + + @asyncio.coroutine + def launch_shard(self, gateway, shard_id): + try: + ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket) + except Exception as e: + import traceback + traceback.print_exc() + log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id) + yield from asyncio.sleep(5.0, loop=self.loop) + yield from self.launch_shard(gateway, shard_id) + + ws.token = self.http.token + ws._connection = self.connection + ws._dispatch = self.dispatch + ws.gateway = gateway + ws.shard_id = shard_id + ws.shard_count = self.shard_count + + # OP HELLO + yield from ws.poll_event() + yield from ws.identify() + log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) + + # keep reading the shard while others connect + self.shards[shard_id] = ret = Shard(ws, self) + compat.create_task(self.pending_reads(ret), loop=self.loop) + yield from asyncio.sleep(5.0, loop=self.loop) + + @asyncio.coroutine + def launch_shards(self): + if self.shard_count is None: + self.shard_count, gateway = yield from self.http.get_bot_gateway() + else: + gateway = yield from self.http.get_gateway() + + self.connection.shard_count = self.shard_count + + for shard_id in range(self.shard_count): + yield from self.launch_shard(gateway, shard_id) + + self._still_sharding = False + @asyncio.coroutine def connect(self): """|coro| @@ -150,11 +203,10 @@ class AutoShardedClient(Client): ConnectionClosed The websocket connection has been terminated. """ - ret = yield from DiscordWebSocket.from_sharded_client(self) - self.shards = [Shard(ws, self) for ws in ret] + yield from self.launch_shards() while not self.is_closed: - pollers = [shard.get_future() for shard in self.shards] + pollers = [shard.get_future() for shard in self.shards.values()] yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) @asyncio.coroutine @@ -166,7 +218,7 @@ class AutoShardedClient(Client): if self.is_closed: return - for shard in self.shards: + for shard in self.shards.values(): yield from shard.ws.close() yield from self.http.close()