From 3a1a215f8b2fbbd11480ba9708b2e548a3cdba6b Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 5 Aug 2020 04:27:11 -0400 Subject: [PATCH] Propagate manual close codes to socket subclass aiohttp seems to not set it during its state machine flow --- discord/errors.py | 4 ++-- discord/gateway.py | 16 +++++++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/discord/errors.py b/discord/errors.py index f8da42d1c..fec1f45db 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -159,10 +159,10 @@ class ConnectionClosed(ClientException): shard_id: Optional[:class:`int`] The shard ID that got closed if applicable. """ - def __init__(self, socket, *, shard_id): + def __init__(self, socket, *, shard_id, code=None): # This exception is just the same exception except # reconfigured to subclass ClientException for users - self.code = socket.close_code + self.code = code or socket.close_code # aiohttp doesn't seem to consistently provide close reason self.reason = '' self.shard_id = shard_id diff --git a/discord/gateway.py b/discord/gateway.py index d54ab5970..6fde1a970 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -239,6 +239,7 @@ class DiscordWebSocket: self.sequence = None self._zlib = zlib.decompressobj() self._buffer = bytearray() + self._close_code = None @property def open(self): @@ -484,7 +485,8 @@ class DiscordWebSocket: return float('inf') if heartbeat is None else heartbeat.latency def _can_handle_close(self): - return self.socket.close_code not in (1000, 4004, 4010, 4011) + code = self._close_code or self.socket.close_code + return code not in (1000, 4004, 4010, 4011) async def poll_event(self): """Polls for a DISPATCH event and handles the general gateway loop. @@ -516,12 +518,13 @@ class DiscordWebSocket: log.info('Timed out receiving packet. Attempting a reconnect.') raise ReconnectWebSocket(self.shard_id) from None + code = self._close_code or self.socket.close_code if self._can_handle_close(): - log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code) + log.info('Websocket closed with %s, attempting a reconnect.', code) raise ReconnectWebSocket(self.shard_id) from None elif self.socket.close_code is not None: - log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code) - raise ConnectionClosed(self.socket, shard_id=self.shard_id) from None + log.info('Websocket closed with %s, cannot reconnect.', code) + raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None async def send(self, data): self._dispatch('socket_raw_send', data) @@ -604,6 +607,7 @@ class DiscordWebSocket: self._keep_alive.stop() self._keep_alive = None + self._close_code = code await self.socket.close(code=code) class DiscordVoiceWebSocket: @@ -654,6 +658,7 @@ class DiscordVoiceWebSocket: self.ws = socket self.loop = loop self._keep_alive = None + self._close_code = None async def send_as_json(self, data): log.debug('Sending voice websocket frame: %s.', data) @@ -820,10 +825,11 @@ class DiscordVoiceWebSocket: raise ConnectionClosed(self.ws, shard_id=None) from msg.data elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): log.debug('Received %s', msg) - raise ConnectionClosed(self.ws, shard_id=None) + raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) async def close(self, code=1000): if self._keep_alive is not None: self._keep_alive.stop() + self._close_code = code await self.ws.close(code=code)