From b8154e365ff584438a8d42354e56881e550bb72e Mon Sep 17 00:00:00 2001 From: Rapptz Date: Tue, 7 Apr 2020 21:53:55 -0400 Subject: [PATCH] Rewrite gateway to use aiohttp instead of websockets --- discord/__main__.py | 2 - discord/client.py | 11 +-- discord/errors.py | 9 +-- discord/ext/tasks/__init__.py | 3 - discord/gateway.py | 132 +++++++++++++++++++--------------- discord/http.py | 11 +++ discord/shard.py | 22 +----- requirements.txt | 1 - 8 files changed, 98 insertions(+), 93 deletions(-) diff --git a/discord/__main__.py b/discord/__main__.py index 102ca30c4..708547484 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -31,7 +31,6 @@ from pathlib import Path import discord import pkg_resources import aiohttp -import websockets import platform def show_version(): @@ -46,7 +45,6 @@ def show_version(): entries.append(' - discord.py pkg_resources: v{0}'.format(pkg.version)) entries.append('- aiohttp v{0.__version__}'.format(aiohttp)) - entries.append('- websockets v{0.__version__}'.format(websockets)) uname = platform.uname() entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) print('\n'.join(entries)) diff --git a/discord/client.py b/discord/client.py index 0fcdcd488..859315699 100644 --- a/discord/client.py +++ b/discord/client.py @@ -32,7 +32,6 @@ import sys import traceback import aiohttp -import websockets from .user import User, Profile from .asset import Asset @@ -497,9 +496,7 @@ class Client: GatewayNotFound, ConnectionClosed, aiohttp.ClientError, - asyncio.TimeoutError, - websockets.InvalidHandshake, - websockets.WebSocketProtocolError) as exc: + asyncio.TimeoutError) as exc: self.dispatch('disconnect') if not reconnect: @@ -632,7 +629,11 @@ class Client: _cleanup_loop(loop) if not future.cancelled(): - return future.result() + try: + return future.result() + except KeyboardInterrupt: + # I am unsure why this gets raised here but suppress it anyway + return None # properties diff --git a/discord/errors.py b/discord/errors.py index 7ab73e9d4..f8da42d1c 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -159,10 +159,11 @@ class ConnectionClosed(ClientException): shard_id: Optional[:class:`int`] The shard ID that got closed if applicable. """ - def __init__(self, original, *, shard_id): + def __init__(self, socket, *, shard_id): # This exception is just the same exception except # reconfigured to subclass ClientException for users - self.code = original.code - self.reason = original.reason + self.code = socket.close_code + # aiohttp doesn't seem to consistently provide close reason + self.reason = '' self.shard_id = shard_id - super().__init__(str(original)) + super().__init__('Shard ID %s WebSocket closed with %s' % (self.shard_id, self.code)) diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 3fa1cb01b..7921b095d 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -27,7 +27,6 @@ DEALINGS IN THE SOFTWARE. import asyncio import datetime import aiohttp -import websockets import discord import inspect import logging @@ -58,8 +57,6 @@ class Loop: discord.ConnectionClosed, aiohttp.ClientError, asyncio.TimeoutError, - websockets.InvalidHandshake, - websockets.WebSocketProtocolError, ) self._before_loop = None diff --git a/discord/gateway.py b/discord/gateway.py index e5cbfe533..59dd3c1a8 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -36,7 +36,7 @@ import threading import traceback import zlib -import websockets +import aiohttp from . import utils from .activity import BaseActivity @@ -60,6 +60,10 @@ class ReconnectWebSocket(Exception): self.resume = resume self.op = 'RESUME' if resume else 'IDENTIFY' +class WebSocketClosure(Exception): + """An exception to make up for the fact that aiohttp doesn't signal closure.""" + pass + EventListener = namedtuple('EventListener', 'predicate event result future') class KeepAliveHandler(threading.Thread): @@ -160,11 +164,17 @@ class VoiceKeepAliveHandler(KeepAliveHandler): self.latency = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) -class DiscordWebSocket(websockets.client.WebSocketClientProtocol): - """Implements a WebSocket for Discord's gateway v6. +# Monkey patch certain things from the aiohttp websocket code +# Check this whenever we update dependencies. +OLD_CLOSE = aiohttp.ClientWebSocketResponse.close + +async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool: + return await OLD_CLOSE(self, code=code, message=message) - This is created through :func:`create_main_websocket`. Library - users should never create this manually. +aiohttp.ClientWebSocketResponse.close = _new_ws_close + +class DiscordWebSocket: + """Implements a WebSocket for Discord's gateway v6. Attributes ----------- @@ -217,9 +227,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): HEARTBEAT_ACK = 11 GUILD_SYNC = 12 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.max_size = None + def __init__(self, socket, *, loop): + self.socket = socket + self.loop = loop + # an empty dispatcher to prevent crashes self._dispatch = lambda *args: None # generic event listeners @@ -234,14 +245,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): self._zlib = zlib.decompressobj() self._buffer = bytearray() + @property + def open(self): + return not self.socket.closed + @classmethod - async def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False): + async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False): """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. """ - gateway = await client.http.get_gateway() - ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) + gateway = gateway or await client.http.get_gateway() + socket = await client.http.ws_connect(gateway) + ws = cls(socket, loop=client.loop) # dynamically add attributes needed ws.token = client.http.token @@ -267,14 +283,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): return ws await ws.resume() - try: - await ws.ensure_open() - except websockets.exceptions.ConnectionClosed: - # ws got closed so let's just do a regular IDENTIFY connect. - log.warning('RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.', shard_id) - return await cls.from_client(client, shard_id=shard_id) - else: - return ws + return ws def wait_for(self, event, predicate, result=None): """Waits for a DISPATCH'd event that meets the predicate. @@ -472,8 +481,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency - def _can_handle_close(self, code): - return code not in (1000, 4004, 4010, 4011) + def _can_handle_close(self): + return self.socket.close_code not in (1000, 4004, 4010, 4011) async def poll_event(self): """Polls for a DISPATCH event and handles the general gateway loop. @@ -484,26 +493,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): The websocket connection was terminated for unhandled reasons. """ try: - msg = await self.recv() - await self.received_message(msg) - except websockets.exceptions.ConnectionClosed as exc: - if self._can_handle_close(exc.code): - log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason) - raise ReconnectWebSocket(self.shard_id) from exc - else: - log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason) - raise ConnectionClosed(exc, shard_id=self.shard_id) from exc + msg = await self.socket.receive() + if msg.type is aiohttp.WSMsgType.TEXT: + await self.received_message(msg.data) + elif msg.type is aiohttp.WSMsgType.BINARY: + await self.received_message(msg.data) + elif msg.type is aiohttp.WSMsgType.ERROR: + log.debug('Received %s', msg) + raise msg.data + elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): + log.debug('Received %s', msg) + raise WebSocketClosure('Unexpected WebSocket closure.') + except WebSocketClosure as e: + if self._can_handle_close(): + log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code) + raise ReconnectWebSocket(self.shard_id) from e + elif self.socket.close_code is not None: + log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code) + raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e async def send(self, data): self._dispatch('socket_raw_send', data) - await super().send(data) + await self.socket.send_str(data) async def send_as_json(self, data): try: await self.send(utils.to_json(data)) - except websockets.exceptions.ConnectionClosed as exc: - if not self._can_handle_close(exc.code): - raise ConnectionClosed(exc, shard_id=self.shard_id) from exc + except RuntimeError as exc: + if not self._can_handle_close(): + raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0): if activity is not None: @@ -570,19 +588,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): log.debug('Updating our voice state to %s.', payload) await self.send_as_json(payload) - async def close(self, code=4000, reason=''): - if self._keep_alive: - self._keep_alive.stop() - - await super().close(code, reason) - - async def close_connection(self, *args, **kwargs): + async def close(self, code=4000): if self._keep_alive: self._keep_alive.stop() - await super().close_connection(*args, **kwargs) + await self.socket.close(code=code) -class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): +class DiscordVoiceWebSocket: """Implements the websocket protocol for handling voice connections. Attributes @@ -626,14 +638,13 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): CLIENT_CONNECT = 12 CLIENT_DISCONNECT = 13 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.max_size = None + def __init__(self, socket): + self.ws = socket self._keep_alive = None async def send_as_json(self, data): log.debug('Sending voice websocket frame: %s.', data) - await self.send(utils.to_json(data)) + await self.ws.send_str(utils.to_json(data)) async def resume(self): state = self._connection @@ -664,7 +675,9 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): async def from_client(cls, client, *, resume=False): """Creates a voice websocket for the :class:`VoiceClient`.""" gateway = 'wss://' + client.endpoint + '/?v=4' - ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) + http = client._state.http + socket = await http.ws_connect(gateway) + ws = cls(socket) ws.gateway = gateway ws._connection = client ws._max_heartbeat_timeout = 60.0 @@ -785,14 +798,19 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): await self.speak(False) async def poll_event(self): - try: - msg = await asyncio.wait_for(self.recv(), timeout=30.0) - await self.received_message(json.loads(msg)) - except websockets.exceptions.ConnectionClosed as exc: - raise ConnectionClosed(exc, shard_id=None) from exc - - async def close_connection(self, *args, **kwargs): - if self._keep_alive: + # This exception is handled up the chain + msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) + if msg.type is aiohttp.WSMsgType.TEXT: + await self.received_message(json.loads(msg.data)) + elif msg.type is aiohttp.WSMsgType.ERROR: + log.debug('Received %s', msg) + raise ConnectionClosed(self.ws, shard_id=None) from msg.data + elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): + log.debug('Received %s', msg) + raise ConnectionClosed(self.ws, shard_id=None) + + async def close(self, code=1000): + if self._keep_alive is not None: self._keep_alive.stop() - await super().close_connection(*args, **kwargs) + await self.ws.close(code=code) diff --git a/discord/http.py b/discord/http.py index 39fcaed0e..a9da4267c 100644 --- a/discord/http.py +++ b/discord/http.py @@ -111,6 +111,17 @@ class HTTPClient: if self.__session.closed: self.__session = aiohttp.ClientSession(connector=self.connector) + async def ws_connect(self, url): + kwargs = { + 'proxy_auth': self.proxy_auth, + 'proxy': self.proxy, + 'max_msg_size': 0, + 'timeout': 30.0, + 'autoclose': False, + } + + return await self.__session.ws_connect(url, **kwargs) + async def request(self, route, *, files=None, **kwargs): bucket = route.bucket method = route.method diff --git a/discord/shard.py b/discord/shard.py index f2feaecb3..1e34a56c6 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -28,8 +28,6 @@ import asyncio import itertools import logging -import websockets - from .state import AutoShardedConnectionState from .client import Client from .gateway import * @@ -191,31 +189,13 @@ class AutoShardedClient(Client): async def launch_shard(self, gateway, shard_id): try: - coro = websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket, compression=None) + coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id) ws = await asyncio.wait_for(coro, timeout=180.0) except Exception: log.info('Failed to connect for shard_id: %s. Retrying...', shard_id) await asyncio.sleep(5.0) return await self.launch_shard(gateway, shard_id) - ws.token = self.http.token - ws._connection = self._connection - ws._discord_parsers = self._connection.parsers - ws._dispatch = self.dispatch - ws.gateway = gateway - ws.shard_id = shard_id - ws.shard_count = self.shard_count - ws._max_heartbeat_timeout = self._connection.heartbeat_timeout - - try: - # OP HELLO - await asyncio.wait_for(ws.poll_event(), timeout=180.0) - await asyncio.wait_for(ws.identify(), timeout=180.0) - except asyncio.TimeoutError: - log.info('Timed out when connecting for shard_id: %s. Retrying...', shard_id) - await asyncio.sleep(5.0) - return await self.launch_shard(gateway, shard_id) - # keep reading the shard while others connect self.shards[shard_id] = ret = Shard(ws, self) ret.launch() diff --git a/requirements.txt b/requirements.txt index 8dfc53013..25c9da588 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ aiohttp>=3.6.0,<3.7.0 -websockets>=6.0,!=7.0,!=8.0,!=8.0.1,<9.0