Browse Source

Rewrite gateway to use aiohttp instead of websockets

pull/5164/head
Rapptz 5 years ago
parent
commit
b8154e365f
  1. 2
      discord/__main__.py
  2. 11
      discord/client.py
  3. 9
      discord/errors.py
  4. 3
      discord/ext/tasks/__init__.py
  5. 132
      discord/gateway.py
  6. 11
      discord/http.py
  7. 22
      discord/shard.py
  8. 1
      requirements.txt

2
discord/__main__.py

@ -31,7 +31,6 @@ from pathlib import Path
import discord import discord
import pkg_resources import pkg_resources
import aiohttp import aiohttp
import websockets
import platform import platform
def show_version(): def show_version():
@ -46,7 +45,6 @@ def show_version():
entries.append(' - discord.py pkg_resources: v{0}'.format(pkg.version)) entries.append(' - discord.py pkg_resources: v{0}'.format(pkg.version))
entries.append('- aiohttp v{0.__version__}'.format(aiohttp)) entries.append('- aiohttp v{0.__version__}'.format(aiohttp))
entries.append('- websockets v{0.__version__}'.format(websockets))
uname = platform.uname() uname = platform.uname()
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
print('\n'.join(entries)) print('\n'.join(entries))

11
discord/client.py

@ -32,7 +32,6 @@ import sys
import traceback import traceback
import aiohttp import aiohttp
import websockets
from .user import User, Profile from .user import User, Profile
from .asset import Asset from .asset import Asset
@ -497,9 +496,7 @@ class Client:
GatewayNotFound, GatewayNotFound,
ConnectionClosed, ConnectionClosed,
aiohttp.ClientError, aiohttp.ClientError,
asyncio.TimeoutError, asyncio.TimeoutError) as exc:
websockets.InvalidHandshake,
websockets.WebSocketProtocolError) as exc:
self.dispatch('disconnect') self.dispatch('disconnect')
if not reconnect: if not reconnect:
@ -632,7 +629,11 @@ class Client:
_cleanup_loop(loop) _cleanup_loop(loop)
if not future.cancelled(): if not future.cancelled():
return future.result() try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
# properties # properties

9
discord/errors.py

@ -159,10 +159,11 @@ class ConnectionClosed(ClientException):
shard_id: Optional[:class:`int`] shard_id: Optional[:class:`int`]
The shard ID that got closed if applicable. The shard ID that got closed if applicable.
""" """
def __init__(self, original, *, shard_id): def __init__(self, socket, *, shard_id):
# This exception is just the same exception except # This exception is just the same exception except
# reconfigured to subclass ClientException for users # reconfigured to subclass ClientException for users
self.code = original.code self.code = socket.close_code
self.reason = original.reason # aiohttp doesn't seem to consistently provide close reason
self.reason = ''
self.shard_id = shard_id self.shard_id = shard_id
super().__init__(str(original)) super().__init__('Shard ID %s WebSocket closed with %s' % (self.shard_id, self.code))

3
discord/ext/tasks/__init__.py

@ -27,7 +27,6 @@ DEALINGS IN THE SOFTWARE.
import asyncio import asyncio
import datetime import datetime
import aiohttp import aiohttp
import websockets
import discord import discord
import inspect import inspect
import logging import logging
@ -58,8 +57,6 @@ class Loop:
discord.ConnectionClosed, discord.ConnectionClosed,
aiohttp.ClientError, aiohttp.ClientError,
asyncio.TimeoutError, asyncio.TimeoutError,
websockets.InvalidHandshake,
websockets.WebSocketProtocolError,
) )
self._before_loop = None self._before_loop = None

132
discord/gateway.py

@ -36,7 +36,7 @@ import threading
import traceback import traceback
import zlib import zlib
import websockets import aiohttp
from . import utils from . import utils
from .activity import BaseActivity from .activity import BaseActivity
@ -60,6 +60,10 @@ class ReconnectWebSocket(Exception):
self.resume = resume self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY' self.op = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
pass
EventListener = namedtuple('EventListener', 'predicate event result future') EventListener = namedtuple('EventListener', 'predicate event result future')
class KeepAliveHandler(threading.Thread): class KeepAliveHandler(threading.Thread):
@ -160,11 +164,17 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency) self.recent_ack_latencies.append(self.latency)
class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # Monkey patch certain things from the aiohttp websocket code
"""Implements a WebSocket for Discord's gateway v6. # Check this whenever we update dependencies.
OLD_CLOSE = aiohttp.ClientWebSocketResponse.close
async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool:
return await OLD_CLOSE(self, code=code, message=message)
This is created through :func:`create_main_websocket`. Library aiohttp.ClientWebSocketResponse.close = _new_ws_close
users should never create this manually.
class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6.
Attributes Attributes
----------- -----------
@ -217,9 +227,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
HEARTBEAT_ACK = 11 HEARTBEAT_ACK = 11
GUILD_SYNC = 12 GUILD_SYNC = 12
def __init__(self, *args, **kwargs): def __init__(self, socket, *, loop):
super().__init__(*args, **kwargs) self.socket = socket
self.max_size = None self.loop = loop
# an empty dispatcher to prevent crashes # an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None self._dispatch = lambda *args: None
# generic event listeners # generic event listeners
@ -234,14 +245,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
self._zlib = zlib.decompressobj() self._zlib = zlib.decompressobj()
self._buffer = bytearray() self._buffer = bytearray()
@property
def open(self):
return not self.socket.closed
@classmethod @classmethod
async def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False): async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only. This is for internal use only.
""" """
gateway = await client.http.get_gateway() gateway = gateway or await client.http.get_gateway()
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) socket = await client.http.ws_connect(gateway)
ws = cls(socket, loop=client.loop)
# dynamically add attributes needed # dynamically add attributes needed
ws.token = client.http.token ws.token = client.http.token
@ -267,14 +283,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
return ws return ws
await ws.resume() await ws.resume()
try: return ws
await ws.ensure_open()
except websockets.exceptions.ConnectionClosed:
# ws got closed so let's just do a regular IDENTIFY connect.
log.warning('RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.', shard_id)
return await cls.from_client(client, shard_id=shard_id)
else:
return ws
def wait_for(self, event, predicate, result=None): def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate. """Waits for a DISPATCH'd event that meets the predicate.
@ -472,8 +481,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
heartbeat = self._keep_alive heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency return float('inf') if heartbeat is None else heartbeat.latency
def _can_handle_close(self, code): def _can_handle_close(self):
return code not in (1000, 4004, 4010, 4011) return self.socket.close_code not in (1000, 4004, 4010, 4011)
async def poll_event(self): async def poll_event(self):
"""Polls for a DISPATCH event and handles the general gateway loop. """Polls for a DISPATCH event and handles the general gateway loop.
@ -484,26 +493,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
The websocket connection was terminated for unhandled reasons. The websocket connection was terminated for unhandled reasons.
""" """
try: try:
msg = await self.recv() msg = await self.socket.receive()
await self.received_message(msg) if msg.type is aiohttp.WSMsgType.TEXT:
except websockets.exceptions.ConnectionClosed as exc: await self.received_message(msg.data)
if self._can_handle_close(exc.code): elif msg.type is aiohttp.WSMsgType.BINARY:
log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason) await self.received_message(msg.data)
raise ReconnectWebSocket(self.shard_id) from exc elif msg.type is aiohttp.WSMsgType.ERROR:
else: log.debug('Received %s', msg)
log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason) raise msg.data
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
log.debug('Received %s', msg)
raise WebSocketClosure('Unexpected WebSocket closure.')
except WebSocketClosure as e:
if self._can_handle_close():
log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code)
raise ReconnectWebSocket(self.shard_id) from e
elif self.socket.close_code is not None:
log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e
async def send(self, data): async def send(self, data):
self._dispatch('socket_raw_send', data) self._dispatch('socket_raw_send', data)
await super().send(data) await self.socket.send_str(data)
async def send_as_json(self, data): async def send_as_json(self, data):
try: try:
await self.send(utils.to_json(data)) await self.send(utils.to_json(data))
except websockets.exceptions.ConnectionClosed as exc: except RuntimeError as exc:
if not self._can_handle_close(exc.code): if not self._can_handle_close():
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0): async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0):
if activity is not None: if activity is not None:
@ -570,19 +588,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
log.debug('Updating our voice state to %s.', payload) log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload) await self.send_as_json(payload)
async def close(self, code=4000, reason=''): async def close(self, code=4000):
if self._keep_alive:
self._keep_alive.stop()
await super().close(code, reason)
async def close_connection(self, *args, **kwargs):
if self._keep_alive: if self._keep_alive:
self._keep_alive.stop() self._keep_alive.stop()
await super().close_connection(*args, **kwargs) await self.socket.close(code=code)
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): class DiscordVoiceWebSocket:
"""Implements the websocket protocol for handling voice connections. """Implements the websocket protocol for handling voice connections.
Attributes Attributes
@ -626,14 +638,13 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
CLIENT_CONNECT = 12 CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13 CLIENT_DISCONNECT = 13
def __init__(self, *args, **kwargs): def __init__(self, socket):
super().__init__(*args, **kwargs) self.ws = socket
self.max_size = None
self._keep_alive = None self._keep_alive = None
async def send_as_json(self, data): async def send_as_json(self, data):
log.debug('Sending voice websocket frame: %s.', data) log.debug('Sending voice websocket frame: %s.', data)
await self.send(utils.to_json(data)) await self.ws.send_str(utils.to_json(data))
async def resume(self): async def resume(self):
state = self._connection state = self._connection
@ -664,7 +675,9 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
async def from_client(cls, client, *, resume=False): async def from_client(cls, client, *, resume=False):
"""Creates a voice websocket for the :class:`VoiceClient`.""" """Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4' gateway = 'wss://' + client.endpoint + '/?v=4'
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) http = client._state.http
socket = await http.ws_connect(gateway)
ws = cls(socket)
ws.gateway = gateway ws.gateway = gateway
ws._connection = client ws._connection = client
ws._max_heartbeat_timeout = 60.0 ws._max_heartbeat_timeout = 60.0
@ -785,14 +798,19 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
await self.speak(False) await self.speak(False)
async def poll_event(self): async def poll_event(self):
try: # This exception is handled up the chain
msg = await asyncio.wait_for(self.recv(), timeout=30.0) msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
await self.received_message(json.loads(msg)) if msg.type is aiohttp.WSMsgType.TEXT:
except websockets.exceptions.ConnectionClosed as exc: await self.received_message(json.loads(msg.data))
raise ConnectionClosed(exc, shard_id=None) from exc elif msg.type is aiohttp.WSMsgType.ERROR:
log.debug('Received %s', msg)
async def close_connection(self, *args, **kwargs): raise ConnectionClosed(self.ws, shard_id=None) from msg.data
if self._keep_alive: elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None)
async def close(self, code=1000):
if self._keep_alive is not None:
self._keep_alive.stop() self._keep_alive.stop()
await super().close_connection(*args, **kwargs) await self.ws.close(code=code)

11
discord/http.py

@ -111,6 +111,17 @@ class HTTPClient:
if self.__session.closed: if self.__session.closed:
self.__session = aiohttp.ClientSession(connector=self.connector) self.__session = aiohttp.ClientSession(connector=self.connector)
async def ws_connect(self, url):
kwargs = {
'proxy_auth': self.proxy_auth,
'proxy': self.proxy,
'max_msg_size': 0,
'timeout': 30.0,
'autoclose': False,
}
return await self.__session.ws_connect(url, **kwargs)
async def request(self, route, *, files=None, **kwargs): async def request(self, route, *, files=None, **kwargs):
bucket = route.bucket bucket = route.bucket
method = route.method method = route.method

22
discord/shard.py

@ -28,8 +28,6 @@ import asyncio
import itertools import itertools
import logging import logging
import websockets
from .state import AutoShardedConnectionState from .state import AutoShardedConnectionState
from .client import Client from .client import Client
from .gateway import * from .gateway import *
@ -191,31 +189,13 @@ class AutoShardedClient(Client):
async def launch_shard(self, gateway, shard_id): async def launch_shard(self, gateway, shard_id):
try: try:
coro = websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket, compression=None) coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0) ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception: except Exception:
log.info('Failed to connect for shard_id: %s. Retrying...', shard_id) log.info('Failed to connect for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0) await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id) return await self.launch_shard(gateway, shard_id)
ws.token = self.http.token
ws._connection = self._connection
ws._discord_parsers = self._connection.parsers
ws._dispatch = self.dispatch
ws.gateway = gateway
ws.shard_id = shard_id
ws.shard_count = self.shard_count
ws._max_heartbeat_timeout = self._connection.heartbeat_timeout
try:
# OP HELLO
await asyncio.wait_for(ws.poll_event(), timeout=180.0)
await asyncio.wait_for(ws.identify(), timeout=180.0)
except asyncio.TimeoutError:
log.info('Timed out when connecting for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id)
# keep reading the shard while others connect # keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self) self.shards[shard_id] = ret = Shard(ws, self)
ret.launch() ret.launch()

1
requirements.txt

@ -1,2 +1 @@
aiohttp>=3.6.0,<3.7.0 aiohttp>=3.6.0,<3.7.0
websockets>=6.0,!=7.0,!=8.0,!=8.0.1,<9.0

Loading…
Cancel
Save