Browse Source

Rewrite of AutoShardedClient to prevent overlapping identify

This is experimental and I'm unsure if it actually works
pull/5164/head
Rapptz 5 years ago
parent
commit
09ecb16680
  1. 9
      discord/client.py
  2. 22
      discord/gateway.py
  3. 103
      discord/shard.py
  4. 2
      discord/state.py

9
discord/client.py

@ -453,11 +453,14 @@ class Client:
while True: while True:
try: try:
await self.ws.poll_event() await self.ws.poll_event()
except ResumeWebSocket: except ReconnectWebSocket as e:
log.info('Got a request to RESUME the websocket.') log.info('Got a request to %s the websocket.', e.op)
self.dispatch('disconnect') self.dispatch('disconnect')
if not e.resume:
await asyncio.sleep(5.0)
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id, coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
sequence=self.ws.sequence, resume=True) sequence=self.ws.sequence, resume=e.resume)
self.ws = await asyncio.wait_for(coro, timeout=180.0) self.ws = await asyncio.wait_for(coro, timeout=180.0)
async def connect(self, *, reconnect=True): async def connect(self, *, reconnect=True):

22
discord/gateway.py

@ -50,13 +50,15 @@ __all__ = (
'KeepAliveHandler', 'KeepAliveHandler',
'VoiceKeepAliveHandler', 'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket', 'DiscordVoiceWebSocket',
'ResumeWebSocket', 'ReconnectWebSocket',
) )
class ResumeWebSocket(Exception): class ReconnectWebSocket(Exception):
"""Signals to initialise via RESUME opcode instead of IDENTIFY.""" """Signals to safely reconnect the websocket."""
def __init__(self, shard_id): def __init__(self, shard_id, *, resume=True):
self.shard_id = shard_id self.shard_id = shard_id
self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY'
EventListener = namedtuple('EventListener', 'predicate event result future') EventListener = namedtuple('EventListener', 'predicate event result future')
@ -385,7 +387,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# internal exception signalling to reconnect. # internal exception signalling to reconnect.
log.debug('Received RECONNECT opcode.') log.debug('Received RECONNECT opcode.')
await self.close() await self.close()
raise ResumeWebSocket(self.shard_id) raise ReconnectWebSocket(self.shard_id)
if op == self.HEARTBEAT_ACK: if op == self.HEARTBEAT_ACK:
self._keep_alive.ack() self._keep_alive.ack()
@ -406,16 +408,14 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
if op == self.INVALIDATE_SESSION: if op == self.INVALIDATE_SESSION:
if data is True: if data is True:
await asyncio.sleep(5.0)
await self.close() await self.close()
raise ResumeWebSocket(self.shard_id) raise ReconnectWebSocket(self.shard_id)
self.sequence = None self.sequence = None
self.session_id = None self.session_id = None
log.info('Shard ID %s session has been invalidated.', self.shard_id) log.info('Shard ID %s session has been invalidated.', self.shard_id)
await asyncio.sleep(5.0) await self.close(code=1000)
await self.identify() raise ReconnectWebSocket(self.shard_id, resume=False)
return
log.warning('Unknown OP code %s.', op) log.warning('Unknown OP code %s.', op)
return return
@ -489,7 +489,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
except websockets.exceptions.ConnectionClosed as exc: except websockets.exceptions.ConnectionClosed as exc:
if self._can_handle_close(exc.code): if self._can_handle_close(exc.code):
log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason) log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason)
raise ResumeWebSocket(self.shard_id) from exc raise ReconnectWebSocket(self.shard_id) from exc
else: else:
log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason) log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason)
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc raise ConnectionClosed(exc, shard_id=self.shard_id) from exc

103
discord/shard.py

@ -33,61 +33,58 @@ import websockets
from .state import AutoShardedConnectionState from .state import AutoShardedConnectionState
from .client import Client from .client import Client
from .gateway import * from .gateway import *
from .errors import ClientException, InvalidArgument from .errors import ClientException, InvalidArgument, ConnectionClosed
from . import utils from . import utils
from .enums import Status from .enums import Status
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class EventType:
Close = 0
Resume = 1
Identify = 2
class Shard: class Shard:
def __init__(self, ws, client): def __init__(self, ws, client):
self.ws = ws self.ws = ws
self._client = client self._client = client
self._dispatch = client.dispatch self._dispatch = client.dispatch
self._queue = client._queue
self.loop = self._client.loop self.loop = self._client.loop
self._current = self.loop.create_future() self._task = None
self._current.set_result(None) # we just need an already done future
self._pending = asyncio.Event()
self._pending_task = None
@property @property
def id(self): def id(self):
return self.ws.shard_id return self.ws.shard_id
def is_pending(self): def launch(self):
return not self._pending.is_set() self._task = self.loop.create_task(self.worker())
def complete_pending_reads(self):
self._pending.set()
async def _pending_reads(self):
try:
while self.is_pending():
await self.poll()
except asyncio.CancelledError:
pass
def launch_pending_reads(self):
self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop)
def wait(self):
return self._pending_task
async def poll(self): async def worker(self):
try: while True:
await self.ws.poll_event() try:
except ResumeWebSocket: await self.ws.poll_event()
log.info('Got a request to RESUME the websocket at Shard ID %s.', self.id) except ReconnectWebSocket as e:
coro = DiscordWebSocket.from_client(self._client, resume=True, shard_id=self.id, etype = EventType.resume if e.resume else EventType.identify
session=self.ws.session_id, sequence=self.ws.sequence) self._queue.put_nowait((etype, self, e))
self._dispatch('disconnect') break
self.ws = await asyncio.wait_for(coro, timeout=180.0) except ConnectionClosed as e:
self._queue.put_nowait((EventType.close, self, e))
def get_future(self): break
if self._current.done():
self._current = asyncio.ensure_future(self.poll(), loop=self.loop) async def reconnect(self, exc):
if self._task is not None and not self._task.done():
self._task.cancel()
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
if not exc.resume:
await asyncio.sleep(5.0)
return self._current coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
self._dispatch('disconnect')
self.ws = await asyncio.wait_for(coro, timeout=180.0)
self.launch()
class AutoShardedClient(Client): class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications """A client similar to :class:`Client` except it handles the complications
@ -134,6 +131,7 @@ class AutoShardedClient(Client):
# the key is the shard_id # the key is the shard_id
self.shards = {} self.shards = {}
self._connection._get_websocket = self._get_websocket self._connection._get_websocket = self._get_websocket
self._queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id=None, *, shard_id=None): def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None: if shard_id is None:
@ -220,8 +218,10 @@ class AutoShardedClient(Client):
# keep reading the shard while others connect # keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self) self.shards[shard_id] = ret = Shard(ws, self)
ret.launch_pending_reads() ret.launch()
await asyncio.sleep(5.0)
if len(self.shards) == self.shard_count:
self._connection.shards_launched.set()
async def launch_shards(self): async def launch_shards(self):
if self.shard_count is None: if self.shard_count is None:
@ -234,26 +234,29 @@ class AutoShardedClient(Client):
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
self._connection.shard_ids = shard_ids self._connection.shard_ids = shard_ids
last_shard_id = shard_ids[-1]
for shard_id in shard_ids: for shard_id in shard_ids:
await self.launch_shard(gateway, shard_id) await self.launch_shard(gateway, shard_id)
if shard_id != last_shard_id:
await asyncio.sleep(5.0)
shards_to_wait_for = [] # shards_to_wait_for = []
for shard in self.shards.values(): # for shard in self.shards.values():
shard.complete_pending_reads() # shard.complete_pending_reads()
shards_to_wait_for.append(shard.wait()) # shards_to_wait_for.append(shard.wait())
# wait for all pending tasks to finish # # wait for all pending tasks to finish
await utils.sane_wait_for(shards_to_wait_for, timeout=300.0) # await utils.sane_wait_for(shards_to_wait_for, timeout=300.0)
async def _connect(self): async def _connect(self):
await self.launch_shards() await self.launch_shards()
while True: while True:
pollers = [shard.get_future() for shard in self.shards.values()] etype, shard, exc = await self._queue.get()
done, _ = await asyncio.wait(pollers, return_when=asyncio.FIRST_COMPLETED) if etype == EventType.close:
for f in done: raise exc
# we wanna re-raise to the main Client.connect handler if applicable elif etype in (EventType.identify, EventType.resume):
f.result() await shard.reconnect(exc)
async def close(self): async def close(self):
"""|coro| """|coro|

2
discord/state.py

@ -1047,6 +1047,7 @@ class AutoShardedConnectionState(ConnectionState):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._ready_task = None self._ready_task = None
self.shard_ids = () self.shard_ids = ()
self.shards_launched = asyncio.Event()
async def chunker(self, guild_id, query='', limit=0, *, shard_id=None, nonce=None): async def chunker(self, guild_id, query='', limit=0, *, shard_id=None, nonce=None):
ws = self._get_websocket(guild_id, shard_id=shard_id) ws = self._get_websocket(guild_id, shard_id=shard_id)
@ -1073,6 +1074,7 @@ class AutoShardedConnectionState(ConnectionState):
log.info('Finished requesting guild member chunks for %d guilds.', len(guilds)) log.info('Finished requesting guild member chunks for %d guilds.', len(guilds))
async def _delay_ready(self): async def _delay_ready(self):
await self.shards_launched.wait()
launch = self._ready_state.launch launch = self._ready_state.launch
while True: while True:
# this snippet of code is basically waiting 2 * shard_ids seconds # this snippet of code is basically waiting 2 * shard_ids seconds

Loading…
Cancel
Save