|
|
@ -36,7 +36,7 @@ import threading |
|
|
|
import traceback |
|
|
|
import zlib |
|
|
|
|
|
|
|
import websockets |
|
|
|
import aiohttp |
|
|
|
|
|
|
|
from . import utils |
|
|
|
from .activity import BaseActivity |
|
|
@ -60,6 +60,10 @@ class ReconnectWebSocket(Exception): |
|
|
|
self.resume = resume |
|
|
|
self.op = 'RESUME' if resume else 'IDENTIFY' |
|
|
|
|
|
|
|
class WebSocketClosure(Exception): |
|
|
|
"""An exception to make up for the fact that aiohttp doesn't signal closure.""" |
|
|
|
pass |
|
|
|
|
|
|
|
EventListener = namedtuple('EventListener', 'predicate event result future') |
|
|
|
|
|
|
|
class KeepAliveHandler(threading.Thread): |
|
|
@ -160,11 +164,17 @@ class VoiceKeepAliveHandler(KeepAliveHandler): |
|
|
|
self.latency = ack_time - self._last_send |
|
|
|
self.recent_ack_latencies.append(self.latency) |
|
|
|
|
|
|
|
class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
"""Implements a WebSocket for Discord's gateway v6. |
|
|
|
# Monkey patch certain things from the aiohttp websocket code |
|
|
|
# Check this whenever we update dependencies. |
|
|
|
OLD_CLOSE = aiohttp.ClientWebSocketResponse.close |
|
|
|
|
|
|
|
async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool: |
|
|
|
return await OLD_CLOSE(self, code=code, message=message) |
|
|
|
|
|
|
|
This is created through :func:`create_main_websocket`. Library |
|
|
|
users should never create this manually. |
|
|
|
aiohttp.ClientWebSocketResponse.close = _new_ws_close |
|
|
|
|
|
|
|
class DiscordWebSocket: |
|
|
|
"""Implements a WebSocket for Discord's gateway v6. |
|
|
|
|
|
|
|
Attributes |
|
|
|
----------- |
|
|
@ -217,9 +227,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
HEARTBEAT_ACK = 11 |
|
|
|
GUILD_SYNC = 12 |
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.max_size = None |
|
|
|
def __init__(self, socket, *, loop): |
|
|
|
self.socket = socket |
|
|
|
self.loop = loop |
|
|
|
|
|
|
|
# an empty dispatcher to prevent crashes |
|
|
|
self._dispatch = lambda *args: None |
|
|
|
# generic event listeners |
|
|
@ -234,14 +245,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
self._zlib = zlib.decompressobj() |
|
|
|
self._buffer = bytearray() |
|
|
|
|
|
|
|
@property |
|
|
|
def open(self): |
|
|
|
return not self.socket.closed |
|
|
|
|
|
|
|
@classmethod |
|
|
|
async def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False): |
|
|
|
async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False): |
|
|
|
"""Creates a main websocket for Discord from a :class:`Client`. |
|
|
|
|
|
|
|
This is for internal use only. |
|
|
|
""" |
|
|
|
gateway = await client.http.get_gateway() |
|
|
|
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) |
|
|
|
gateway = gateway or await client.http.get_gateway() |
|
|
|
socket = await client.http.ws_connect(gateway) |
|
|
|
ws = cls(socket, loop=client.loop) |
|
|
|
|
|
|
|
# dynamically add attributes needed |
|
|
|
ws.token = client.http.token |
|
|
@ -267,14 +283,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
return ws |
|
|
|
|
|
|
|
await ws.resume() |
|
|
|
try: |
|
|
|
await ws.ensure_open() |
|
|
|
except websockets.exceptions.ConnectionClosed: |
|
|
|
# ws got closed so let's just do a regular IDENTIFY connect. |
|
|
|
log.warning('RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.', shard_id) |
|
|
|
return await cls.from_client(client, shard_id=shard_id) |
|
|
|
else: |
|
|
|
return ws |
|
|
|
return ws |
|
|
|
|
|
|
|
def wait_for(self, event, predicate, result=None): |
|
|
|
"""Waits for a DISPATCH'd event that meets the predicate. |
|
|
@ -472,8 +481,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
heartbeat = self._keep_alive |
|
|
|
return float('inf') if heartbeat is None else heartbeat.latency |
|
|
|
|
|
|
|
def _can_handle_close(self, code): |
|
|
|
return code not in (1000, 4004, 4010, 4011) |
|
|
|
def _can_handle_close(self): |
|
|
|
return self.socket.close_code not in (1000, 4004, 4010, 4011) |
|
|
|
|
|
|
|
async def poll_event(self): |
|
|
|
"""Polls for a DISPATCH event and handles the general gateway loop. |
|
|
@ -484,26 +493,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
The websocket connection was terminated for unhandled reasons. |
|
|
|
""" |
|
|
|
try: |
|
|
|
msg = await self.recv() |
|
|
|
await self.received_message(msg) |
|
|
|
except websockets.exceptions.ConnectionClosed as exc: |
|
|
|
if self._can_handle_close(exc.code): |
|
|
|
log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason) |
|
|
|
raise ReconnectWebSocket(self.shard_id) from exc |
|
|
|
else: |
|
|
|
log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason) |
|
|
|
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc |
|
|
|
msg = await self.socket.receive() |
|
|
|
if msg.type is aiohttp.WSMsgType.TEXT: |
|
|
|
await self.received_message(msg.data) |
|
|
|
elif msg.type is aiohttp.WSMsgType.BINARY: |
|
|
|
await self.received_message(msg.data) |
|
|
|
elif msg.type is aiohttp.WSMsgType.ERROR: |
|
|
|
log.debug('Received %s', msg) |
|
|
|
raise msg.data |
|
|
|
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): |
|
|
|
log.debug('Received %s', msg) |
|
|
|
raise WebSocketClosure('Unexpected WebSocket closure.') |
|
|
|
except WebSocketClosure as e: |
|
|
|
if self._can_handle_close(): |
|
|
|
log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code) |
|
|
|
raise ReconnectWebSocket(self.shard_id) from e |
|
|
|
elif self.socket.close_code is not None: |
|
|
|
log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code) |
|
|
|
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e |
|
|
|
|
|
|
|
async def send(self, data): |
|
|
|
self._dispatch('socket_raw_send', data) |
|
|
|
await super().send(data) |
|
|
|
await self.socket.send_str(data) |
|
|
|
|
|
|
|
async def send_as_json(self, data): |
|
|
|
try: |
|
|
|
await self.send(utils.to_json(data)) |
|
|
|
except websockets.exceptions.ConnectionClosed as exc: |
|
|
|
if not self._can_handle_close(exc.code): |
|
|
|
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc |
|
|
|
except RuntimeError as exc: |
|
|
|
if not self._can_handle_close(): |
|
|
|
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc |
|
|
|
|
|
|
|
async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0): |
|
|
|
if activity is not None: |
|
|
@ -570,19 +588,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
log.debug('Updating our voice state to %s.', payload) |
|
|
|
await self.send_as_json(payload) |
|
|
|
|
|
|
|
async def close(self, code=4000, reason=''): |
|
|
|
if self._keep_alive: |
|
|
|
self._keep_alive.stop() |
|
|
|
|
|
|
|
await super().close(code, reason) |
|
|
|
|
|
|
|
async def close_connection(self, *args, **kwargs): |
|
|
|
async def close(self, code=4000): |
|
|
|
if self._keep_alive: |
|
|
|
self._keep_alive.stop() |
|
|
|
|
|
|
|
await super().close_connection(*args, **kwargs) |
|
|
|
await self.socket.close(code=code) |
|
|
|
|
|
|
|
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
class DiscordVoiceWebSocket: |
|
|
|
"""Implements the websocket protocol for handling voice connections. |
|
|
|
|
|
|
|
Attributes |
|
|
@ -626,14 +638,13 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
CLIENT_CONNECT = 12 |
|
|
|
CLIENT_DISCONNECT = 13 |
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.max_size = None |
|
|
|
def __init__(self, socket): |
|
|
|
self.ws = socket |
|
|
|
self._keep_alive = None |
|
|
|
|
|
|
|
async def send_as_json(self, data): |
|
|
|
log.debug('Sending voice websocket frame: %s.', data) |
|
|
|
await self.send(utils.to_json(data)) |
|
|
|
await self.ws.send_str(utils.to_json(data)) |
|
|
|
|
|
|
|
async def resume(self): |
|
|
|
state = self._connection |
|
|
@ -664,7 +675,9 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
async def from_client(cls, client, *, resume=False): |
|
|
|
"""Creates a voice websocket for the :class:`VoiceClient`.""" |
|
|
|
gateway = 'wss://' + client.endpoint + '/?v=4' |
|
|
|
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) |
|
|
|
http = client._state.http |
|
|
|
socket = await http.ws_connect(gateway) |
|
|
|
ws = cls(socket) |
|
|
|
ws.gateway = gateway |
|
|
|
ws._connection = client |
|
|
|
ws._max_heartbeat_timeout = 60.0 |
|
|
@ -785,14 +798,19 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
await self.speak(False) |
|
|
|
|
|
|
|
async def poll_event(self): |
|
|
|
try: |
|
|
|
msg = await asyncio.wait_for(self.recv(), timeout=30.0) |
|
|
|
await self.received_message(json.loads(msg)) |
|
|
|
except websockets.exceptions.ConnectionClosed as exc: |
|
|
|
raise ConnectionClosed(exc, shard_id=None) from exc |
|
|
|
|
|
|
|
async def close_connection(self, *args, **kwargs): |
|
|
|
if self._keep_alive: |
|
|
|
# This exception is handled up the chain |
|
|
|
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) |
|
|
|
if msg.type is aiohttp.WSMsgType.TEXT: |
|
|
|
await self.received_message(json.loads(msg.data)) |
|
|
|
elif msg.type is aiohttp.WSMsgType.ERROR: |
|
|
|
log.debug('Received %s', msg) |
|
|
|
raise ConnectionClosed(self.ws, shard_id=None) from msg.data |
|
|
|
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): |
|
|
|
log.debug('Received %s', msg) |
|
|
|
raise ConnectionClosed(self.ws, shard_id=None) |
|
|
|
|
|
|
|
async def close(self, code=1000): |
|
|
|
if self._keep_alive is not None: |
|
|
|
self._keep_alive.stop() |
|
|
|
|
|
|
|
await super().close_connection(*args, **kwargs) |
|
|
|
await self.ws.close(code=code) |
|
|
|