diff --git a/discord/client.py b/discord/client.py index 862321559..0fcdcd488 100644 --- a/discord/client.py +++ b/discord/client.py @@ -453,11 +453,14 @@ class Client: while True: try: await self.ws.poll_event() - except ResumeWebSocket: - log.info('Got a request to RESUME the websocket.') + except ReconnectWebSocket as e: + log.info('Got a request to %s the websocket.', e.op) self.dispatch('disconnect') + if not e.resume: + await asyncio.sleep(5.0) + coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id, - sequence=self.ws.sequence, resume=True) + sequence=self.ws.sequence, resume=e.resume) self.ws = await asyncio.wait_for(coro, timeout=180.0) async def connect(self, *, reconnect=True): diff --git a/discord/gateway.py b/discord/gateway.py index 2ddd2a236..c2a432f8d 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -50,13 +50,15 @@ __all__ = ( 'KeepAliveHandler', 'VoiceKeepAliveHandler', 'DiscordVoiceWebSocket', - 'ResumeWebSocket', + 'ReconnectWebSocket', ) -class ResumeWebSocket(Exception): - """Signals to initialise via RESUME opcode instead of IDENTIFY.""" - def __init__(self, shard_id): +class ReconnectWebSocket(Exception): + """Signals to safely reconnect the websocket.""" + def __init__(self, shard_id, *, resume=True): self.shard_id = shard_id + self.resume = resume + self.op = 'RESUME' if resume else 'IDENTIFY' EventListener = namedtuple('EventListener', 'predicate event result future') @@ -385,7 +387,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # internal exception signalling to reconnect. log.debug('Received RECONNECT opcode.') await self.close() - raise ResumeWebSocket(self.shard_id) + raise ReconnectWebSocket(self.shard_id) if op == self.HEARTBEAT_ACK: self._keep_alive.ack() @@ -406,16 +408,14 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): if op == self.INVALIDATE_SESSION: if data is True: - await asyncio.sleep(5.0) await self.close() - raise ResumeWebSocket(self.shard_id) + raise ReconnectWebSocket(self.shard_id) self.sequence = None self.session_id = None log.info('Shard ID %s session has been invalidated.', self.shard_id) - await asyncio.sleep(5.0) - await self.identify() - return + await self.close(code=1000) + raise ReconnectWebSocket(self.shard_id, resume=False) log.warning('Unknown OP code %s.', op) return @@ -489,7 +489,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): except websockets.exceptions.ConnectionClosed as exc: if self._can_handle_close(exc.code): log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason) - raise ResumeWebSocket(self.shard_id) from exc + raise ReconnectWebSocket(self.shard_id) from exc else: log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason) raise ConnectionClosed(exc, shard_id=self.shard_id) from exc diff --git a/discord/shard.py b/discord/shard.py index 6d599dab6..ad564bb29 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -33,61 +33,58 @@ import websockets from .state import AutoShardedConnectionState from .client import Client from .gateway import * -from .errors import ClientException, InvalidArgument +from .errors import ClientException, InvalidArgument, ConnectionClosed from . import utils from .enums import Status log = logging.getLogger(__name__) +class EventType: + Close = 0 + Resume = 1 + Identify = 2 + class Shard: def __init__(self, ws, client): self.ws = ws self._client = client self._dispatch = client.dispatch + self._queue = client._queue self.loop = self._client.loop - self._current = self.loop.create_future() - self._current.set_result(None) # we just need an already done future - self._pending = asyncio.Event() - self._pending_task = None + self._task = None @property def id(self): return self.ws.shard_id - def is_pending(self): - return not self._pending.is_set() - - def complete_pending_reads(self): - self._pending.set() - - async def _pending_reads(self): - try: - while self.is_pending(): - await self.poll() - except asyncio.CancelledError: - pass - - def launch_pending_reads(self): - self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop) - - def wait(self): - return self._pending_task + def launch(self): + self._task = self.loop.create_task(self.worker()) - async def poll(self): - try: - await self.ws.poll_event() - except ResumeWebSocket: - log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id) - coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id, - session=self.ws.session_id, sequence=self.ws.sequence) - self._dispatch('disconnect') - self.ws = await asyncio.wait_for(coro, timeout=180.0) - - def get_future(self): - if self._current.done(): - self._current = asyncio.ensure_future(self.poll(), loop=self.loop) + async def worker(self): + while True: + try: + await self.ws.poll_event() + except ReconnectWebSocket as e: + etype = EventType.resume if e.resume else EventType.identify + self._queue.put_nowait((etype, self, e)) + break + except ConnectionClosed as e: + self._queue.put_nowait((EventType.close, self, e)) + break + + async def reconnect(self, exc): + if self._task is not None and not self._task.done(): + self._task.cancel() + + log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) + if not exc.resume: + await asyncio.sleep(5.0) - return self._current + coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, + session=self.ws.session_id, sequence=self.ws.sequence) + self._dispatch('disconnect') + self.ws = await asyncio.wait_for(coro, timeout=180.0) + self.launch() class AutoShardedClient(Client): """A client similar to :class:`Client` except it handles the complications @@ -134,6 +131,7 @@ class AutoShardedClient(Client): # the key is the shard_id self.shards = {} self._connection._get_websocket = self._get_websocket + self._queue = asyncio.PriorityQueue() def _get_websocket(self, guild_id=None, *, shard_id=None): if shard_id is None: @@ -220,8 +218,10 @@ class AutoShardedClient(Client): # keep reading the shard while others connect self.shards[shard_id] = ret = Shard(ws, self) - ret.launch_pending_reads() - await asyncio.sleep(5.0) + ret.launch() + + if len(self.shards) == self.shard_count: + self._connection.shards_launched.set() async def launch_shards(self): if self.shard_count is None: @@ -234,26 +234,29 @@ class AutoShardedClient(Client): shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) self._connection.shard_ids = shard_ids + last_shard_id = shard_ids[-1] for shard_id in shard_ids: await self.launch_shard(gateway, shard_id) + if shard_id != last_shard_id: + await asyncio.sleep(5.0) - shards_to_wait_for = [] - for shard in self.shards.values(): - shard.complete_pending_reads() - shards_to_wait_for.append(shard.wait()) + # shards_to_wait_for = [] + # for shard in self.shards.values(): + # shard.complete_pending_reads() + # shards_to_wait_for.append(shard.wait()) - # wait for all pending tasks to finish - await utils.sane_wait_for(shards_to_wait_for, timeout=300.0) + # # wait for all pending tasks to finish + # await utils.sane_wait_for(shards_to_wait_for, timeout=300.0) async def _connect(self): await self.launch_shards() while True: - pollers = [shard.get_future() for shard in self.shards.values()] - done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED) - for f in done: - # we wanna re-raise to the main Client.connect handler if applicable - f.result() + etype, shard, exc = await self._queue.get() + if etype == EventType.close: + raise exc + elif etype in (EventType.identify, EventType.resume): + await shard.reconnect(exc) async def close(self): """|coro| diff --git a/discord/state.py b/discord/state.py index f84d85ba2..6148889d9 100644 --- a/discord/state.py +++ b/discord/state.py @@ -1047,6 +1047,7 @@ class AutoShardedConnectionState(ConnectionState): super().__init__(*args, **kwargs) self._ready_task = None self.shard_ids = () + self.shards_launched = asyncio.Event() async def chunker(self, guild_id, query='', limit=0, *, shard_id=None, nonce=None): ws = self._get_websocket(guild_id, shard_id=shard_id) @@ -1073,6 +1074,7 @@ class AutoShardedConnectionState(ConnectionState): log.info('Finished requesting guild member chunks for %d guilds.', len(guilds)) async def _delay_ready(self): + await self.shards_launched.wait() launch = self._ready_state.launch while True: # this snippet of code is basically waiting 2 * shard_ids seconds