Browse Source

Change the way shards are launched in AutoShardedClient.

pull/447/head
Rapptz 8 years ago
parent
commit
b5bed9ef33
  1. 29
      discord/gateway.py
  2. 64
      discord/shard.py

29
discord/gateway.py

@ -214,35 +214,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
else:
return ws
@classmethod
@asyncio.coroutine
def from_sharded_client(cls, client):
if client.shard_count is None:
client.shard_count, gateway = yield from client.http.get_bot_gateway()
else:
gateway = yield from client.http.get_gateway()
ret = []
client.connection.shard_count = client.shard_count
for shard_id in range(client.shard_count):
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
ws.token = client.http.token
ws._connection = client.connection
ws._dispatch = client.dispatch
ws.gateway = gateway
ws.shard_id = shard_id
ws.shard_count = client.shard_count
# OP HELLO
yield from ws.poll_event()
yield from ws.identify()
ret.append(ws)
log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
yield from asyncio.sleep(5.0, loop=client.loop)
return ret
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.

64
discord/shard.py

@ -32,6 +32,7 @@ from . import compat
import asyncio
import logging
import websockets
log = logging.getLogger(__name__)
@ -93,8 +94,10 @@ class AutoShardedClient(Client):
syncer=self._syncer, http=self.http, loop=self.loop, **kwargs)
# instead of a single websocket, we have multiple
# the index is the shard_id
self.shards = []
# the key is the shard_id
self.shards = {}
self._still_sharding = True
@asyncio.coroutine
def request_offline_members(self, guild, *, shard_id=None):
@ -135,6 +138,56 @@ class AutoShardedClient(Client):
ws = self.shards[shard_id].ws
yield from ws.send_as_json(payload)
@asyncio.coroutine
def pending_reads(self, shard):
try:
while self._still_sharding:
yield from shard.poll()
except asyncio.CancelledError:
pass
@asyncio.coroutine
def launch_shard(self, gateway, shard_id):
try:
ws = yield from websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket)
except Exception as e:
import traceback
traceback.print_exc()
log.info('Failed to connect for shard_id: %s. Retrying...' % shard_id)
yield from asyncio.sleep(5.0, loop=self.loop)
yield from self.launch_shard(gateway, shard_id)
ws.token = self.http.token
ws._connection = self.connection
ws._dispatch = self.dispatch
ws.gateway = gateway
ws.shard_id = shard_id
ws.shard_count = self.shard_count
# OP HELLO
yield from ws.poll_event()
yield from ws.identify()
log.info('Sent IDENTIFY payload to create the websocket for shard_id: %s' % shard_id)
# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
compat.create_task(self.pending_reads(ret), loop=self.loop)
yield from asyncio.sleep(5.0, loop=self.loop)
@asyncio.coroutine
def launch_shards(self):
if self.shard_count is None:
self.shard_count, gateway = yield from self.http.get_bot_gateway()
else:
gateway = yield from self.http.get_gateway()
self.connection.shard_count = self.shard_count
for shard_id in range(self.shard_count):
yield from self.launch_shard(gateway, shard_id)
self._still_sharding = False
@asyncio.coroutine
def connect(self):
"""|coro|
@ -150,11 +203,10 @@ class AutoShardedClient(Client):
ConnectionClosed
The websocket connection has been terminated.
"""
ret = yield from DiscordWebSocket.from_sharded_client(self)
self.shards = [Shard(ws, self) for ws in ret]
yield from self.launch_shards()
while not self.is_closed:
pollers = [shard.get_future() for shard in self.shards]
pollers = [shard.get_future() for shard in self.shards.values()]
yield from asyncio.wait(pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
@asyncio.coroutine
@ -166,7 +218,7 @@ class AutoShardedClient(Client):
if self.is_closed:
return
for shard in self.shards:
for shard in self.shards.values():
yield from shard.ws.close()
yield from self.http.close()

Loading…
Cancel
Save