|
|
@ -32,6 +32,7 @@ from . import compat |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import logging |
|
|
|
import websockets |
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
|
@ -93,8 +94,10 @@ class AutoShardedClient(Client): |
|
|
|
syncer=self._syncer, http=self.http, loop=self.loop, **kwargs) |
|
|
|
|
|
|
|
# instead of a single websocket, we have multiple |
|
|
|
# the index is the shard_id |
|
|
|
self.shards = [] |
|
|
|
# the key is the shard_id |
|
|
|
self.shards = {} |
|
|
|
|
|
|
|
self._still_sharding = True |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def request_offline_members(self, guild, *, shard_id=None): |
|
|
@ -135,6 +138,56 @@ class AutoShardedClient(Client): |
|
|
|
ws = self.shards[shard_id].ws |
|
|
|
yield from ws.send_as_json(payload) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def pending_reads(self, shard): |
|
|
|
try: |
|
|
|
while self._still_sharding: |
|
|
|
yield from shard.poll() |
|
|
|
except asyncio.CancelledError: |
|
|
|
pass |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def launch_shard(self, gateway, shard_id): |
|
|
|
try: |
|
|
|
ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket) |
|
|
|
except Exception as e: |
|
|
|
import traceback |
|
|
|
traceback.print_exc() |
|
|
|
log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id) |
|
|
|
yield from asyncio.sleep(5.0, loop=self.loop) |
|
|
|
yield from self.launch_shard(gateway, shard_id) |
|
|
|
|
|
|
|
ws.token = self.http.token |
|
|
|
ws._connection = self.connection |
|
|
|
ws._dispatch = self.dispatch |
|
|
|
ws.gateway = gateway |
|
|
|
ws.shard_id = shard_id |
|
|
|
ws.shard_count = self.shard_count |
|
|
|
|
|
|
|
# OP HELLO |
|
|
|
yield from ws.poll_event() |
|
|
|
yield from ws.identify() |
|
|
|
log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id) |
|
|
|
|
|
|
|
# keep reading the shard while others connect |
|
|
|
self.shards[shard_id] = ret = Shard(ws, self) |
|
|
|
compat.create_task(self.pending_reads(ret), loop=self.loop) |
|
|
|
yield from asyncio.sleep(5.0, loop=self.loop) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def launch_shards(self): |
|
|
|
if self.shard_count is None: |
|
|
|
self.shard_count, gateway = yield from self.http.get_bot_gateway() |
|
|
|
else: |
|
|
|
gateway = yield from self.http.get_gateway() |
|
|
|
|
|
|
|
self.connection.shard_count = self.shard_count |
|
|
|
|
|
|
|
for shard_id in range(self.shard_count): |
|
|
|
yield from self.launch_shard(gateway, shard_id) |
|
|
|
|
|
|
|
self._still_sharding = False |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def connect(self): |
|
|
|
"""|coro| |
|
|
@ -150,11 +203,10 @@ class AutoShardedClient(Client): |
|
|
|
ConnectionClosed |
|
|
|
The websocket connection has been terminated. |
|
|
|
""" |
|
|
|
ret = yield from DiscordWebSocket.from_sharded_client(self) |
|
|
|
self.shards = [Shard(ws, self) for ws in ret] |
|
|
|
yield from self.launch_shards() |
|
|
|
|
|
|
|
while not self.is_closed: |
|
|
|
pollers = [shard.get_future() for shard in self.shards] |
|
|
|
pollers = [shard.get_future() for shard in self.shards.values()] |
|
|
|
yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
@ -166,7 +218,7 @@ class AutoShardedClient(Client): |
|
|
|
if self.is_closed: |
|
|
|
return |
|
|
|
|
|
|
|
for shard in self.shards: |
|
|
|
for shard in self.shards.values(): |
|
|
|
yield from shard.ws.close() |
|
|
|
|
|
|
|
yield from self.http.close() |
|
|
|