Browse Source

Make every shard maintain its own reconnect loop

Previously if a disconnect happened the client would get in a bad state
and certain shards would be double sending due to unhandled exceptions
raising back to Client.connect and causing all shards to be reconnected
again.

This new code overrides Client.connect to have more finer control and
allow each individual shard to maintain its own reconnect loop and then
serially request reconnection to ensure that IDENTIFYs are not
overlapping.
pull/5164/head
Rapptz 5 years ago
parent
commit
f658fcf164
  1. 84
      discord/shard.py

84
discord/shard.py

@ -28,10 +28,13 @@ import asyncio
import itertools import itertools
import logging import logging
import aiohttp
from .state import AutoShardedConnectionState from .state import AutoShardedConnectionState
from .client import Client from .client import Client
from .backoff import ExponentialBackoff
from .gateway import * from .gateway import *
from .errors import ClientException, InvalidArgument, ConnectionClosed from .errors import ClientException, InvalidArgument, HTTPException, GatewayNotFound, ConnectionClosed
from . import utils from . import utils
from .enums import Status from .enums import Status
@ -39,8 +42,9 @@ log = logging.getLogger(__name__)
class EventType: class EventType:
close = 0 close = 0
resume = 1 reconnect = 1
identify = 2 resume = 2
identify = 3
class EventItem: class EventItem:
__slots__ = ('type', 'shard', 'error') __slots__ = ('type', 'shard', 'error')
@ -70,7 +74,18 @@ class Shard:
self._dispatch = client.dispatch self._dispatch = client.dispatch
self._queue = client._queue self._queue = client._queue
self.loop = self._client.loop self.loop = self._client.loop
self._disconnect = False
self._reconnect = client._reconnect
self._backoff = ExponentialBackoff()
self._task = None self._task = None
self._handled_exceptions = (
OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError,
)
@property @property
def id(self): def id(self):
@ -79,6 +94,33 @@ class Shard:
def launch(self): def launch(self):
self._task = self.loop.create_task(self.worker()) self._task = self.loop.create_task(self.worker())
def _cancel_task(self):
if self._task is not None and not self._task.done():
self._task.cancel()
async def close(self):
self._cancel_task()
await self.ws.close(code=1000)
async def _handle_disconnect(self, e):
self._dispatch('disconnect')
if not self._reconnect:
self._queue.put_nowait(EventItem(EventType.close, self, e))
return
if self._client.is_closed():
return
if isinstance(e, ConnectionClosed):
if e.code != 1000:
self._queue.put_nowait(EventItem(EventType.close, self, e))
return
retry = self._backoff.delay()
log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
await asyncio.sleep(retry)
self._queue.put_nowait(EventItem(EventType.reconnect, self, e))
async def worker(self): async def worker(self):
while not self._client.is_closed(): while not self._client.is_closed():
try: try:
@ -87,14 +129,12 @@ class Shard:
etype = EventType.resume if e.resume else EventType.identify etype = EventType.resume if e.resume else EventType.identify
self._queue.put_nowait(EventItem(etype, self, e)) self._queue.put_nowait(EventItem(etype, self, e))
break break
except ConnectionClosed as e: except self._handled_exceptions as e:
self._queue.put_nowait(EventItem(EventType.close, self, e)) await self._handle_disconnect(e)
break break
async def reconnect(self, exc): async def reidentify(self, exc):
if self._task is not None and not self._task.done(): self._cancel_task()
self._task.cancel()
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence) session=self.ws.session_id, sequence=self.ws.sequence)
@ -102,6 +142,16 @@ class Shard:
self.ws = await asyncio.wait_for(coro, timeout=180.0) self.ws = await asyncio.wait_for(coro, timeout=180.0)
self.launch() self.launch()
async def reconnect(self):
self._cancel_task()
try:
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
self.ws = await asyncio.wait_for(coro, timeout=180.0)
except self._handled_exceptions as e:
await self._handle_disconnect(e)
else:
self.launch()
class AutoShardedClient(Client): class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications """A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single of sharding for the user into a more manageable and transparent single
@ -235,15 +285,21 @@ class AutoShardedClient(Client):
self._connection.shards_launched.set() self._connection.shards_launched.set()
async def _connect(self): async def connect(self, *, reconnect=True):
self._reconnect = reconnect
await self.launch_shards() await self.launch_shards()
while True: while not self.is_closed():
item = await self._queue.get() item = await self._queue.get()
if item.type == EventType.close: if item.type == EventType.close:
raise item.error await self.close()
if isinstance(item.error, ConnectionClosed) and item.error.code != 1000:
raise item.error
return
elif item.type in (EventType.identify, EventType.resume): elif item.type in (EventType.identify, EventType.resume):
await item.shard.reconnect(item.error) await item.shard.reidentify(item.error)
elif item.type == EventType.reconnect:
await item.shard.reconnect()
async def close(self): async def close(self):
"""|coro| """|coro|
@ -261,7 +317,7 @@ class AutoShardedClient(Client):
except Exception: except Exception:
pass pass
to_close = [asyncio.ensure_future(shard.ws.close(code=1000), loop=self.loop) for shard in self.shards.values()] to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.shards.values()]
if to_close: if to_close:
await asyncio.wait(to_close) await asyncio.wait(to_close)

Loading…
Cancel
Save