diff --git a/discord/shard.py b/discord/shard.py index f817fb9ab..31777816b 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -28,10 +28,13 @@ import asyncio import itertools import logging +import aiohttp + from .state import AutoShardedConnectionState from .client import Client +from .backoff import ExponentialBackoff from .gateway import * -from .errors import ClientException, InvalidArgument, ConnectionClosed +from .errors import ClientException, InvalidArgument, HTTPException, GatewayNotFound, ConnectionClosed from . import utils from .enums import Status @@ -39,8 +42,9 @@ log = logging.getLogger(__name__) class EventType: close = 0 - resume = 1 - identify = 2 + reconnect = 1 + resume = 2 + identify = 3 class EventItem: __slots__ = ('type', 'shard', 'error') @@ -70,7 +74,18 @@ class Shard: self._dispatch = client.dispatch self._queue = client._queue self.loop = self._client.loop + self._disconnect = False + self._reconnect = client._reconnect + self._backoff = ExponentialBackoff() self._task = None + self._handled_exceptions = ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) @property def id(self): @@ -79,6 +94,33 @@ class Shard: def launch(self): self._task = self.loop.create_task(self.worker()) + def _cancel_task(self): + if self._task is not None and not self._task.done(): + self._task.cancel() + + async def close(self): + self._cancel_task() + await self.ws.close(code=1000) + + async def _handle_disconnect(self, e): + self._dispatch('disconnect') + if not self._reconnect: + self._queue.put_nowait(EventItem(EventType.close, self, e)) + return + + if self._client.is_closed(): + return + + if isinstance(e, ConnectionClosed): + if e.code != 1000: + self._queue.put_nowait(EventItem(EventType.close, self, e)) + return + + retry = self._backoff.delay() + log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e) + await asyncio.sleep(retry) + self._queue.put_nowait(EventItem(EventType.reconnect, self, e)) + async def worker(self): while not self._client.is_closed(): try: @@ -87,14 +129,12 @@ class Shard: etype = EventType.resume if e.resume else EventType.identify self._queue.put_nowait(EventItem(etype, self, e)) break - except ConnectionClosed as e: - self._queue.put_nowait(EventItem(EventType.close, self, e)) + except self._handled_exceptions as e: + await self._handle_disconnect(e) break - async def reconnect(self, exc): - if self._task is not None and not self._task.done(): - self._task.cancel() - + async def reidentify(self, exc): + self._cancel_task() log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, session=self.ws.session_id, sequence=self.ws.sequence) @@ -102,6 +142,16 @@ class Shard: self.ws = await asyncio.wait_for(coro, timeout=180.0) self.launch() + async def reconnect(self): + self._cancel_task() + try: + coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) + self.ws = await asyncio.wait_for(coro, timeout=180.0) + except self._handled_exceptions as e: + await self._handle_disconnect(e) + else: + self.launch() + class AutoShardedClient(Client): """A client similar to :class:`Client` except it handles the complications of sharding for the user into a more manageable and transparent single @@ -235,15 +285,21 @@ class AutoShardedClient(Client): self._connection.shards_launched.set() - async def _connect(self): + async def connect(self, *, reconnect=True): + self._reconnect = reconnect await self.launch_shards() - while True: + while not self.is_closed(): item = await self._queue.get() if item.type == EventType.close: - raise item.error + await self.close() + if isinstance(item.error, ConnectionClosed) and item.error.code != 1000: + raise item.error + return elif item.type in (EventType.identify, EventType.resume): - await item.shard.reconnect(item.error) + await item.shard.reidentify(item.error) + elif item.type == EventType.reconnect: + await item.shard.reconnect() async def close(self): """|coro| @@ -261,7 +317,7 @@ class AutoShardedClient(Client): except Exception: pass - to_close = [asyncio.ensure_future(shard.ws.close(code=1000), loop=self.loop) for shard in self.shards.values()] + to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.shards.values()] if to_close: await asyncio.wait(to_close)