Browse Source

Refactor voice websocket into gateway.py

pull/190/head
Rapptz 9 years ago
parent
commit
c1b5a52823
  1. 108
      discord/client.py
  2. 190
      discord/gateway.py
  3. 124
      discord/voice_client.py

108
discord/client.py

@ -38,7 +38,7 @@ from .errors import *
from .state import ConnectionState from .state import ConnectionState
from .permissions import Permissions from .permissions import Permissions
from . import utils, compat from . import utils, compat
from .enums import ChannelType, ServerRegion, Status from .enums import ChannelType, ServerRegion
from .voice_client import VoiceClient from .voice_client import VoiceClient
from .iterators import LogsFromIterator from .iterators import LogsFromIterator
from .gateway import * from .gateway import *
@ -48,7 +48,7 @@ import aiohttp
import websockets import websockets
import logging, traceback import logging, traceback
import sys, time, re, json import sys, re
import tempfile, os, hashlib import tempfile, os, hashlib
import itertools import itertools
from random import randint as random_integer from random import randint as random_integer
@ -140,11 +140,6 @@ class Client:
self._is_logged_in = asyncio.Event(loop=self.loop) self._is_logged_in = asyncio.Event(loop=self.loop)
self._is_ready = asyncio.Event(loop=self.loop) self._is_ready = asyncio.Event(loop=self.loop)
# These two events correspond to the two events necessary
# for a connection to be made
self._voice_data_found = asyncio.Event(loop=self.loop)
self._session_id_found = asyncio.Event(loop=self.loop)
# internals # internals
def _get_cache_filename(self, email): def _get_cache_filename(self, email):
@ -280,72 +275,6 @@ class Client:
print('Ignoring exception in {}'.format(event_method), file=sys.stderr) print('Ignoring exception in {}'.format(event_method), file=sys.stderr)
traceback.print_exc() traceback.print_exc()
@asyncio.coroutine
def received_message(self, msg):
self.dispatch('socket_raw_receive', msg)
if isinstance(msg, bytes):
msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB
msg = msg.decode('utf-8')
msg = json.loads(msg)
log.debug('WebSocket Event: {}'.format(msg))
self.dispatch('socket_response', msg)
op = msg.get('op')
data = msg.get('d')
if 's' in msg:
self.sequence = msg['s']
if op == 7:
# redirect op code
yield from self.ws.close()
yield from self.redirect_websocket(data.get('url'))
return
if op != 0:
log.info('Unhandled op {}'.format(op))
return
event = msg.get('t')
is_ready = event == 'READY'
if is_ready:
self.connection.clear()
self.session_id = data['session_id']
if is_ready or event == 'RESUMED':
interval = data['heartbeat_interval'] / 1000.0
self.keep_alive = compat.create_task(self.keep_alive_handler(interval), loop=self.loop)
if event == 'VOICE_STATE_UPDATE':
user_id = data.get('user_id')
if user_id == self.user.id:
if self.is_voice_connected():
self.voice.channel = self.get_channel(data.get('channel_id'))
self.session_id = data.get('session_id')
log.debug('Session ID found: {}'.format(self.session_id))
self._session_id_found.set()
if event == 'VOICE_SERVER_UPDATE':
self._voice_data_found.data = data
log.debug('Voice connection data found: {}'.format(data))
self._voice_data_found.set()
return
parser = 'parse_' + event.lower()
try:
func = getattr(self.connection, parser)
except AttributeError:
log.info('Unhandled event {}'.format(event))
else:
result = func(data)
# login state management # login state management
@asyncio.coroutine @asyncio.coroutine
@ -2442,7 +2371,6 @@ class Client:
:class:`VoiceClient` :class:`VoiceClient`
A voice client that is fully connected to the voice server. A voice client that is fully connected to the voice server.
""" """
if self.is_voice_connected(): if self.is_voice_connected():
raise ClientException('Already connected to a voice channel') raise ClientException('Already connected to a voice channel')
@ -2454,29 +2382,29 @@ class Client:
log.info('attempting to join voice channel {0.name}'.format(channel)) log.info('attempting to join voice channel {0.name}'.format(channel))
payload = { def session_id_found(data):
'op': 4, user_id = data.get('user_id')
'd': { return user_id == self.user.id
'guild_id': channel.server.id,
'channel_id': channel.id,
'self_mute': False,
'self_deaf': False
}
}
yield from self._send_ws(utils.to_json(payload)) # register the futures for waiting
yield from asyncio.wait_for(self._session_id_found.wait(), timeout=5.0, loop=self.loop) session_id_future = self.ws.wait_for('VOICE_STATE_UPDATE', session_id_found)
yield from asyncio.wait_for(self._voice_data_found.wait(), timeout=5.0, loop=self.loop) voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True)
self._session_id_found.clear() # request joining
self._voice_data_found.clear() yield from self.ws.voice_state(channel.server.id, channel.id)
session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop)
data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop)
# todo: multivoice
if self.is_voice_connected():
self.voice.channel = self.get_channel(session_id_data.get('channel_id'))
kwargs = { kwargs = {
'user': self.user, 'user': self.user,
'channel': channel, 'channel': channel,
'data': self._voice_data_found.data, 'data': data,
'loop': self.loop, 'loop': self.loop,
'session_id': self.session_id, 'session_id': session_id_data.get('session_id'),
'main_ws': self.ws 'main_ws': self.ws
} }

190
discord/gateway.py

@ -36,11 +36,13 @@ import logging
import zlib, time, json import zlib, time, json
from collections import namedtuple from collections import namedtuple
import threading import threading
import struct
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', __all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket',
'KeepAliveHandler' ] 'KeepAliveHandler', 'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket' ]
class ReconnectWebSocket(Exception): class ReconnectWebSocket(Exception):
"""Signals to handle the RECONNECT opcode.""" """Signals to handle the RECONNECT opcode."""
@ -56,13 +58,13 @@ class KeepAliveHandler(threading.Thread):
self.ws = ws self.ws = ws
self.interval = interval self.interval = interval
self.daemon = True self.daemon = True
self.msg = 'Keeping websocket alive with sequence {0[d]}'
self._stop = threading.Event() self._stop = threading.Event()
def run(self): def run(self):
while not self._stop.wait(self.interval): while not self._stop.wait(self.interval):
data = self.get_payload() data = self.get_payload()
msg = 'Keeping websocket alive with sequence {0[d]}'.format(data) log.debug(self.msg.format(data))
log.debug(msg)
coro = self.ws.send_as_json(data) coro = self.ws.send_as_json(data)
f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop) f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try: try:
@ -80,6 +82,17 @@ class KeepAliveHandler(threading.Thread):
def stop(self): def stop(self):
self._stop.set() self._stop.set()
class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.msg = 'Keeping voice websocket alive with timestamp {0[d]}'
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000)
}
@asyncio.coroutine @asyncio.coroutine
def get_gateway(token, *, loop=None): def get_gateway(token, *, loop=None):
@ -212,7 +225,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
connection=client.connection, connection=client.connection,
loop=client.loop) loop=client.loop)
def wait_for(self, event, predicate, result): def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate. """Waits for a DISPATCH'd event that meets the predicate.
Parameters Parameters
@ -224,7 +237,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
properties. The data parameter is the 'd' key in the JSON message. properties. The data parameter is the 'd' key in the JSON message.
result result
A function that takes the same data parameter and executes to send A function that takes the same data parameter and executes to send
the result to the future. the result to the future. If None, returns the data.
Returns Returns
-------- --------
@ -281,6 +294,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# "reconnect" can only be handled by the Client # "reconnect" can only be handled by the Client
# so we terminate our connection and raise an # so we terminate our connection and raise an
# internal exception signalling to reconnect. # internal exception signalling to reconnect.
log.info('Receivede RECONNECT opcode.')
yield from self.close() yield from self.close()
raise ReconnectWebSocket() raise ReconnectWebSocket()
@ -332,7 +346,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
removed.append(index) removed.append(index)
else: else:
if valid: if valid:
future.set_result(entry.result) ret = data if entry.result is None else entry.result(data)
future.set_result(ret)
removed.append(index) removed.append(index)
for index in reversed(removed): for index in reversed(removed):
@ -352,6 +367,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from self.received_message(msg) yield from self.received_message(msg)
except websockets.exceptions.ConnectionClosed as e: except websockets.exceptions.ConnectionClosed as e:
if e.code in (4008, 4009) or e.code in range(1001, 1015): if e.code in (4008, 4009) or e.code in range(1001, 1015):
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
raise ReconnectWebSocket() from e raise ReconnectWebSocket() from e
else: else:
raise ConnectionClosed(e) from e raise ConnectionClosed(e) from e
@ -394,9 +410,171 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
status = Status.idle if idle_since else Status.online status = Status.idle if idle_since else Status.online
me.status = status me.status = status
@asyncio.coroutine
def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
payload = {
'op': self.VOICE_STATE,
'd': {
'guild_id': guild_id,
'channel_id': channel_id,
'self_mute': self_mute,
'self_deaf': self_deaf
}
}
yield from self.send_as_json(payload)
@asyncio.coroutine
def close(self, code=1000, reason=''):
if self._keep_alive:
self._keep_alive.stop()
yield from super().close(code, reason)
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
"""Implements the websocket protocol for handling voice connections.
Attributes
-----------
IDENTIFY
Send only. Starts a new voice session.
SELECT_PROTOCOL
Send only. Tells discord what encryption mode and how to connect for voice.
READY
Receive only. Tells the websocket that the initial connection has completed.
HEARTBEAT
Send only. Keeps your websocket connection alive.
SESSION_DESCRIPTION
Receive only. Gives you the secret key required for voice.
SPEAKING
Send only. Notifies the client if you are currently speaking.
"""
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = None
self._keep_alive = None
@asyncio.coroutine
def send_as_json(self, data):
yield from self.send(utils.to_json(data))
@classmethod
@asyncio.coroutine
def from_client(cls, client):
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
ws.gateway = gateway
ws._connection = client
identify = {
'op': cls.IDENTIFY,
'd': {
'server_id': client.guild_id,
'user_id': client.user.id,
'session_id': client.session_id,
'token': client.token
}
}
yield from ws.send_as_json(identify)
return ws
@asyncio.coroutine
def select_protocol(self, ip, port):
payload = {
'op': self.SELECT_PROTOCOL,
'd': {
'protocol': 'udp',
'data': {
'address': ip,
'port': port,
'mode': 'xsalsa20_poly1305'
}
}
}
yield from self.send_as_json(payload)
log.debug('Selected protocol as {}'.format(payload))
@asyncio.coroutine
def speak(self, is_speaking=True):
payload = {
'op': self.SPEAKING,
'd': {
'speaking': is_speaking,
'delay': 0
}
}
yield from self.send_as_json(payload)
log.debug('Voice speaking now set to {}'.format(is_speaking))
@asyncio.coroutine
def received_message(self, msg):
log.debug('Voice websocket frame received: {}'.format(msg))
op = msg.get('op')
data = msg.get('d')
if op == self.READY:
interval = (data['heartbeat_interval'] / 100.0) - 5
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval)
self._keep_alive.start()
yield from self.initial_connection(data)
elif op == self.SESSION_DESCRIPTION:
yield from self.load_secret_key(data)
@asyncio.coroutine
def initial_connection(self, data):
state = self._connection
state.ssrc = data.get('ssrc')
state.voice_port = data.get('port')
packet = bytearray(70)
struct.pack_into('>I', packet, 0, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = yield from self.loop.sock_recv(state.socket, 70)
log.debug('received packet in initial_connection: {}'.format(recv))
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
# the port is a little endian unsigned short in the last two bytes
# yes, this is different endianness from everything else
state.port = struct.unpack_from('<H', recv, len(recv) - 2)[0]
log.debug('detected ip: {0.ip} port: {0.port}'.format(state))
yield from self.select_protocol(state.ip, state.port)
log.info('selected the voice protocol for use')
@asyncio.coroutine
def load_secret_key(self, data):
log.info('received secret key for voice connection')
self._connection.secret_key = data.get('secret_key')
yield from self.speak()
@asyncio.coroutine
def poll_event(self):
try:
msg = yield from self.recv()
yield from self.received_message(json.loads(msg))
except websockets.exceptions.ConnectionClosed as e:
raise ConnectionClosed(e) from e
@asyncio.coroutine @asyncio.coroutine
def close(self, code=1000, reason=''): def close(self, code=1000, reason=''):
if self._keep_alive: if self._keep_alive:
self._keep_alive.stop() self._keep_alive.stop()
yield from super().close(code, reason) yield from super().close(code, reason)

124
discord/voice_client.py

@ -55,6 +55,7 @@ import nacl.secret
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from . import utils from . import utils
from .gateway import *
from .errors import ClientException, InvalidArgument from .errors import ClientException, InvalidArgument
from .opus import Encoder as OpusEncoder from .opus import Encoder as OpusEncoder
@ -173,7 +174,6 @@ class VoiceClient:
self.sequence = 0 self.sequence = 0
self.timestamp = 0 self.timestamp = 0
self.encoder = OpusEncoder(48000, 2) self.encoder = OpusEncoder(48000, 2)
self.secret_key = []
log.info('created opus encoder with {0.__dict__}'.format(self.encoder)) log.info('created opus encoder with {0.__dict__}'.format(self.encoder))
def checked_add(self, attr, value, limit): def checked_add(self, attr, value, limit):
@ -183,87 +183,6 @@ class VoiceClient:
else: else:
setattr(self, attr, val + value) setattr(self, attr, val + value)
@asyncio.coroutine
def keep_alive_handler(self, delay):
try:
while True:
payload = {
'op': 3,
'd': int(time.time())
}
msg = 'Keeping voice websocket alive with timestamp {}'
log.debug(msg.format(payload['d']))
yield from self.ws.send(utils.to_json(payload))
yield from asyncio.sleep(delay)
except asyncio.CancelledError:
pass
@asyncio.coroutine
def received_message(self, msg):
log.debug('Voice websocket frame received: {}'.format(msg))
op = msg.get('op')
data = msg.get('d')
if op == 2:
delay = (data['heartbeat_interval'] / 100.0) - 5
self.keep_alive = utils.create_task(self.keep_alive_handler(delay), loop=self.loop)
yield from self.initial_connection(data)
elif op == 4:
yield from self.connection_ready(data)
@asyncio.coroutine
def initial_connection(self, data):
self.ssrc = data.get('ssrc')
self.voice_port = data.get('port')
packet = bytearray(70)
struct.pack_into('>I', packet, 0, self.ssrc)
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
recv = yield from self.loop.sock_recv(self.socket, 70)
log.debug('received packet in initial_connection: {}'.format(recv))
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
self.ip = recv[ip_start:ip_end].decode('ascii')
# the port is a little endian unsigned short in the last two bytes
# yes, this is different endianness from everything else
self.port = struct.unpack_from('<H', recv, len(recv) - 2)[0]
log.debug('detected ip: {} port: {}'.format(self.ip, self.port))
payload = {
'op': 1,
'd': {
'protocol': 'udp',
'data': {
'address': self.ip,
'port': self.port,
'mode': 'xsalsa20_poly1305'
}
}
}
yield from self.ws.send(utils.to_json(payload))
log.debug('sent {} to initialize voice connection'.format(payload))
log.info('initial voice connection is done')
@asyncio.coroutine
def connection_ready(self, data):
log.info('voice connection is now ready')
self.secret_key = data.get('secret_key')
speaking = {
'op': 5,
'd': {
'speaking': True,
'delay': 0
}
}
yield from self.ws.send(utils.to_json(speaking))
self._connected.set()
# connection related # connection related
@asyncio.coroutine @asyncio.coroutine
@ -275,28 +194,15 @@ class VoiceClient:
self.socket.setblocking(False) self.socket.setblocking(False)
log.info('Voice endpoint found {0.endpoint} (IP: {0.endpoint_ip})'.format(self)) log.info('Voice endpoint found {0.endpoint} (IP: {0.endpoint_ip})'.format(self))
self.ws = yield from websockets.connect('wss://' + self.endpoint, loop=self.loop)
self.ws.max_size = None
payload = {
'op': 0,
'd': {
'server_id': self.guild_id,
'user_id': self.user.id,
'session_id': self.session_id,
'token': self.token
}
}
yield from self.ws.send(utils.to_json(payload))
self.ws = yield from DiscordVoiceWebSocket.from_client(self)
while not self._connected.is_set(): while not self._connected.is_set():
msg = yield from self.ws.recv() yield from self.ws.poll_event()
if msg is None: if hasattr(self, 'secret_key'):
yield from self.disconnect() # we have a secret key, so we don't need to poll
raise ClientException('Unexpected websocket close on voice websocket') # websocket events anymore
self._connected.set()
yield from self.received_message(json.loads(msg)) break
@asyncio.coroutine @asyncio.coroutine
def disconnect(self): def disconnect(self):
@ -310,22 +216,10 @@ class VoiceClient:
if not self._connected.is_set(): if not self._connected.is_set():
return return
self.keep_alive.cancel()
self.socket.close() self.socket.close()
self._connected.clear() self._connected.clear()
yield from self.ws.close() yield from self.ws.close()
yield from self.main_ws.voice_state(self.guild_id, None, self_mute=True)
payload = {
'op': 4,
'd': {
'guild_id': self.guild_id,
'channel_id': None,
'self_mute': True,
'self_deaf': False
}
}
yield from self.main_ws.send(utils.to_json(payload))
def is_connected(self): def is_connected(self):
"""bool : Indicates if the voice client is connected to voice.""" """bool : Indicates if the voice client is connected to voice."""

Loading…
Cancel
Save