From c1b5a528230765ce67fe885e0fd50058e0a820c9 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 27 Apr 2016 17:37:25 -0400 Subject: [PATCH] Refactor voice websocket into gateway.py --- discord/client.py | 108 ++++------------------- discord/gateway.py | 190 ++++++++++++++++++++++++++++++++++++++-- discord/voice_client.py | 124 ++------------------------ 3 files changed, 211 insertions(+), 211 deletions(-) diff --git a/discord/client.py b/discord/client.py index d8dd9b12d..aaee8d647 100644 --- a/discord/client.py +++ b/discord/client.py @@ -38,7 +38,7 @@ from .errors import * from .state import ConnectionState from .permissions import Permissions from . import utils, compat -from .enums import ChannelType, ServerRegion, Status +from .enums import ChannelType, ServerRegion from .voice_client import VoiceClient from .iterators import LogsFromIterator from .gateway import * @@ -48,7 +48,7 @@ import aiohttp import websockets import logging, traceback -import sys, time, re, json +import sys, re import tempfile, os, hashlib import itertools from random import randint as random_integer @@ -140,11 +140,6 @@ class Client: self._is_logged_in = asyncio.Event(loop=self.loop) self._is_ready = asyncio.Event(loop=self.loop) - # These two events correspond to the two events necessary - # for a connection to be made - self._voice_data_found = asyncio.Event(loop=self.loop) - self._session_id_found = asyncio.Event(loop=self.loop) - # internals def _get_cache_filename(self, email): @@ -280,72 +275,6 @@ class Client: print('Ignoring exception in {}'.format(event_method), file=sys.stderr) traceback.print_exc() - @asyncio.coroutine - def received_message(self, msg): - self.dispatch('socket_raw_receive', msg) - - if isinstance(msg, bytes): - msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB - msg = msg.decode('utf-8') - - msg = json.loads(msg) - - log.debug('WebSocket Event: {}'.format(msg)) - self.dispatch('socket_response', msg) - - op = msg.get('op') - data = msg.get('d') - - if 's' in msg: - self.sequence = msg['s'] - - if op == 7: - # redirect op code - yield from self.ws.close() - yield from self.redirect_websocket(data.get('url')) - return - - if op != 0: - log.info('Unhandled op {}'.format(op)) - return - - event = msg.get('t') - is_ready = event == 'READY' - - if is_ready: - self.connection.clear() - self.session_id = data['session_id'] - - if is_ready or event == 'RESUMED': - interval = data['heartbeat_interval'] / 1000.0 - self.keep_alive = compat.create_task(self.keep_alive_handler(interval), loop=self.loop) - - if event == 'VOICE_STATE_UPDATE': - user_id = data.get('user_id') - if user_id == self.user.id: - if self.is_voice_connected(): - self.voice.channel = self.get_channel(data.get('channel_id')) - - self.session_id = data.get('session_id') - log.debug('Session ID found: {}'.format(self.session_id)) - self._session_id_found.set() - - - if event == 'VOICE_SERVER_UPDATE': - self._voice_data_found.data = data - log.debug('Voice connection data found: {}'.format(data)) - self._voice_data_found.set() - return - - parser = 'parse_' + event.lower() - - try: - func = getattr(self.connection, parser) - except AttributeError: - log.info('Unhandled event {}'.format(event)) - else: - result = func(data) - # login state management @asyncio.coroutine @@ -2442,7 +2371,6 @@ class Client: :class:`VoiceClient` A voice client that is fully connected to the voice server. """ - if self.is_voice_connected(): raise ClientException('Already connected to a voice channel') @@ -2454,29 +2382,29 @@ class Client: log.info('attempting to join voice channel {0.name}'.format(channel)) - payload = { - 'op': 4, - 'd': { - 'guild_id': channel.server.id, - 'channel_id': channel.id, - 'self_mute': False, - 'self_deaf': False - } - } + def session_id_found(data): + user_id = data.get('user_id') + return user_id == self.user.id - yield from self._send_ws(utils.to_json(payload)) - yield from asyncio.wait_for(self._session_id_found.wait(), timeout=5.0, loop=self.loop) - yield from asyncio.wait_for(self._voice_data_found.wait(), timeout=5.0, loop=self.loop) + # register the futures for waiting + session_id_future = self.ws.wait_for('VOICE_STATE_UPDATE', session_id_found) + voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True) - self._session_id_found.clear() - self._voice_data_found.clear() + # request joining + yield from self.ws.voice_state(channel.server.id, channel.id) + session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop) + data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop) + + # todo: multivoice + if self.is_voice_connected(): + self.voice.channel = self.get_channel(session_id_data.get('channel_id')) kwargs = { 'user': self.user, 'channel': channel, - 'data': self._voice_data_found.data, + 'data': data, 'loop': self.loop, - 'session_id': self.session_id, + 'session_id': session_id_data.get('session_id'), 'main_ws': self.ws } diff --git a/discord/gateway.py b/discord/gateway.py index 2b4fc4dc4..ccfc98df9 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -36,11 +36,13 @@ import logging import zlib, time, json from collections import namedtuple import threading +import struct log = logging.getLogger(__name__) __all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', - 'KeepAliveHandler' ] + 'KeepAliveHandler', 'VoiceKeepAliveHandler', + 'DiscordVoiceWebSocket' ] class ReconnectWebSocket(Exception): """Signals to handle the RECONNECT opcode.""" @@ -56,13 +58,13 @@ class KeepAliveHandler(threading.Thread): self.ws = ws self.interval = interval self.daemon = True + self.msg = 'Keeping websocket alive with sequence {0[d]}' self._stop = threading.Event() def run(self): while not self._stop.wait(self.interval): data = self.get_payload() - msg = 'Keeping websocket alive with sequence {0[d]}'.format(data) - log.debug(msg) + log.debug(self.msg.format(data)) coro = self.ws.send_as_json(data) f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: @@ -80,6 +82,17 @@ class KeepAliveHandler(threading.Thread): def stop(self): self._stop.set() +class VoiceKeepAliveHandler(KeepAliveHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.msg = 'Keeping voice websocket alive with timestamp {0[d]}' + + def get_payload(self): + return { + 'op': self.ws.HEARTBEAT, + 'd': int(time.time() * 1000) + } + @asyncio.coroutine def get_gateway(token, *, loop=None): @@ -212,7 +225,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): connection=client.connection, loop=client.loop) - def wait_for(self, event, predicate, result): + def wait_for(self, event, predicate, result=None): """Waits for a DISPATCH'd event that meets the predicate. Parameters @@ -224,7 +237,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): properties. The data parameter is the 'd' key in the JSON message. result A function that takes the same data parameter and executes to send - the result to the future. + the result to the future. If None, returns the data. Returns -------- @@ -281,6 +294,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): # "reconnect" can only be handled by the Client # so we terminate our connection and raise an # internal exception signalling to reconnect. + log.info('Receivede RECONNECT opcode.') yield from self.close() raise ReconnectWebSocket() @@ -332,7 +346,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): removed.append(index) else: if valid: - future.set_result(entry.result) + ret = data if entry.result is None else entry.result(data) + future.set_result(ret) removed.append(index) for index in reversed(removed): @@ -352,6 +367,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): yield from self.received_message(msg) except websockets.exceptions.ConnectionClosed as e: if e.code in (4008, 4009) or e.code in range(1001, 1015): + log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e)) raise ReconnectWebSocket() from e else: raise ConnectionClosed(e) from e @@ -394,9 +410,171 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): status = Status.idle if idle_since else Status.online me.status = status + @asyncio.coroutine + def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + payload = { + 'op': self.VOICE_STATE, + 'd': { + 'guild_id': guild_id, + 'channel_id': channel_id, + 'self_mute': self_mute, + 'self_deaf': self_deaf + } + } + + yield from self.send_as_json(payload) + + @asyncio.coroutine + def close(self, code=1000, reason=''): + if self._keep_alive: + self._keep_alive.stop() + + yield from super().close(code, reason) + +class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): + """Implements the websocket protocol for handling voice connections. + + Attributes + ----------- + IDENTIFY + Send only. Starts a new voice session. + SELECT_PROTOCOL + Send only. Tells discord what encryption mode and how to connect for voice. + READY + Receive only. Tells the websocket that the initial connection has completed. + HEARTBEAT + Send only. Keeps your websocket connection alive. + SESSION_DESCRIPTION + Receive only. Gives you the secret key required for voice. + SPEAKING + Send only. Notifies the client if you are currently speaking. + """ + + IDENTIFY = 0 + SELECT_PROTOCOL = 1 + READY = 2 + HEARTBEAT = 3 + SESSION_DESCRIPTION = 4 + SPEAKING = 5 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_size = None + self._keep_alive = None + + @asyncio.coroutine + def send_as_json(self, data): + yield from self.send(utils.to_json(data)) + + @classmethod + @asyncio.coroutine + def from_client(cls, client): + """Creates a voice websocket for the :class:`VoiceClient`.""" + gateway = 'wss://' + client.endpoint + ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) + ws.gateway = gateway + ws._connection = client + + identify = { + 'op': cls.IDENTIFY, + 'd': { + 'server_id': client.guild_id, + 'user_id': client.user.id, + 'session_id': client.session_id, + 'token': client.token + } + } + + yield from ws.send_as_json(identify) + return ws + + @asyncio.coroutine + def select_protocol(self, ip, port): + payload = { + 'op': self.SELECT_PROTOCOL, + 'd': { + 'protocol': 'udp', + 'data': { + 'address': ip, + 'port': port, + 'mode': 'xsalsa20_poly1305' + } + } + } + + yield from self.send_as_json(payload) + log.debug('Selected protocol as {}'.format(payload)) + + @asyncio.coroutine + def speak(self, is_speaking=True): + payload = { + 'op': self.SPEAKING, + 'd': { + 'speaking': is_speaking, + 'delay': 0 + } + } + + yield from self.send_as_json(payload) + log.debug('Voice speaking now set to {}'.format(is_speaking)) + + @asyncio.coroutine + def received_message(self, msg): + log.debug('Voice websocket frame received: {}'.format(msg)) + op = msg.get('op') + data = msg.get('d') + + if op == self.READY: + interval = (data['heartbeat_interval'] / 100.0) - 5 + self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval) + self._keep_alive.start() + yield from self.initial_connection(data) + elif op == self.SESSION_DESCRIPTION: + yield from self.load_secret_key(data) + + @asyncio.coroutine + def initial_connection(self, data): + state = self._connection + state.ssrc = data.get('ssrc') + state.voice_port = data.get('port') + packet = bytearray(70) + struct.pack_into('>I', packet, 0, state.ssrc) + state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) + recv = yield from self.loop.sock_recv(state.socket, 70) + log.debug('received packet in initial_connection: {}'.format(recv)) + + # the ip is ascii starting at the 4th byte and ending at the first null + ip_start = 4 + ip_end = recv.index(0, ip_start) + state.ip = recv[ip_start:ip_end].decode('ascii') + + # the port is a little endian unsigned short in the last two bytes + # yes, this is different endianness from everything else + state.port = struct.unpack_from('I', packet, 0, self.ssrc) - self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) - recv = yield from self.loop.sock_recv(self.socket, 70) - log.debug('received packet in initial_connection: {}'.format(recv)) - - # the ip is ascii starting at the 4th byte and ending at the first null - ip_start = 4 - ip_end = recv.index(0, ip_start) - self.ip = recv[ip_start:ip_end].decode('ascii') - - # the port is a little endian unsigned short in the last two bytes - # yes, this is different endianness from everything else - self.port = struct.unpack_from('