From f658fcf16457638d3d6fe9636eb61f587d7681ac Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 11 Apr 2020 18:37:00 -0400 Subject: [PATCH] Make every shard maintain its own reconnect loop Previously if a disconnect happened the client would get in a bad state and certain shards would be double sending due to unhandled exceptions raising back to Client.connect and causing all shards to be reconnected again. This new code overrides Client.connect to have finer control and allow each individual shard to maintain its own reconnect loop and then serially request reconnection to ensure that IDENTIFYs are not overlapping. --- discord/shard.py | 84 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/discord/shard.py b/discord/shard.py index f817fb9ab..31777816b 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -28,10 +28,13 @@ import asyncio import itertools import logging +import aiohttp + from .state import AutoShardedConnectionState from .client import Client +from .backoff import ExponentialBackoff from .gateway import * -from .errors import ClientException, InvalidArgument, ConnectionClosed +from .errors import ClientException, InvalidArgument, HTTPException, GatewayNotFound, ConnectionClosed from . 
import utils from .enums import Status @@ -39,8 +42,9 @@ log = logging.getLogger(__name__) class EventType: close = 0 - resume = 1 - identify = 2 + reconnect = 1 + resume = 2 + identify = 3 class EventItem: __slots__ = ('type', 'shard', 'error') @@ -70,7 +74,18 @@ class Shard: self._dispatch = client.dispatch self._queue = client._queue self.loop = self._client.loop + self._disconnect = False + self._reconnect = client._reconnect + self._backoff = ExponentialBackoff() self._task = None + self._handled_exceptions = ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) @property def id(self): @@ -79,6 +94,33 @@ class Shard: def launch(self): self._task = self.loop.create_task(self.worker()) + def _cancel_task(self): + if self._task is not None and not self._task.done(): + self._task.cancel() + + async def close(self): + self._cancel_task() + await self.ws.close(code=1000) + + async def _handle_disconnect(self, e): + self._dispatch('disconnect') + if not self._reconnect: + self._queue.put_nowait(EventItem(EventType.close, self, e)) + return + + if self._client.is_closed(): + return + + if isinstance(e, ConnectionClosed): + if e.code != 1000: + self._queue.put_nowait(EventItem(EventType.close, self, e)) + return + + retry = self._backoff.delay() + log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e) + await asyncio.sleep(retry) + self._queue.put_nowait(EventItem(EventType.reconnect, self, e)) + async def worker(self): while not self._client.is_closed(): try: @@ -87,14 +129,12 @@ class Shard: etype = EventType.resume if e.resume else EventType.identify self._queue.put_nowait(EventItem(etype, self, e)) break - except ConnectionClosed as e: - self._queue.put_nowait(EventItem(EventType.close, self, e)) + except self._handled_exceptions as e: + await self._handle_disconnect(e) break - async def reconnect(self, exc): - if self._task is not None and not self._task.done(): - 
self._task.cancel() - + async def reidentify(self, exc): + self._cancel_task() log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, session=self.ws.session_id, sequence=self.ws.sequence) @@ -102,6 +142,16 @@ class Shard: self.ws = await asyncio.wait_for(coro, timeout=180.0) self.launch() + async def reconnect(self): + self._cancel_task() + try: + coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) + self.ws = await asyncio.wait_for(coro, timeout=180.0) + except self._handled_exceptions as e: + await self._handle_disconnect(e) + else: + self.launch() + class AutoShardedClient(Client): """A client similar to :class:`Client` except it handles the complications of sharding for the user into a more manageable and transparent single @@ -235,15 +285,21 @@ class AutoShardedClient(Client): self._connection.shards_launched.set() - async def _connect(self): + async def connect(self, *, reconnect=True): + self._reconnect = reconnect await self.launch_shards() - while True: + while not self.is_closed(): item = await self._queue.get() if item.type == EventType.close: - raise item.error + await self.close() + if isinstance(item.error, ConnectionClosed) and item.error.code != 1000: + raise item.error + return elif item.type in (EventType.identify, EventType.resume): - await item.shard.reconnect(item.error) + await item.shard.reidentify(item.error) + elif item.type == EventType.reconnect: + await item.shard.reconnect() async def close(self): """|coro| @@ -261,7 +317,7 @@ class AutoShardedClient(Client): except Exception: pass - to_close = [asyncio.ensure_future(shard.ws.close(code=1000), loop=self.loop) for shard in self.shards.values()] + to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.shards.values()] if to_close: await asyncio.wait(to_close)