Browse Source

Refactor voice websocket into gateway.py

pull/190/head
Rapptz 9 years ago
parent
commit
c1b5a52823
  1. 108
      discord/client.py
  2. 190
      discord/gateway.py
  3. 124
      discord/voice_client.py

108
discord/client.py

@ -38,7 +38,7 @@ from .errors import *
from .state import ConnectionState
from .permissions import Permissions
from . import utils, compat
from .enums import ChannelType, ServerRegion, Status
from .enums import ChannelType, ServerRegion
from .voice_client import VoiceClient
from .iterators import LogsFromIterator
from .gateway import *
@ -48,7 +48,7 @@ import aiohttp
import websockets
import logging, traceback
import sys, time, re, json
import sys, re
import tempfile, os, hashlib
import itertools
from random import randint as random_integer
@ -140,11 +140,6 @@ class Client:
self._is_logged_in = asyncio.Event(loop=self.loop)
self._is_ready = asyncio.Event(loop=self.loop)
# These two events correspond to the two events necessary
# for a connection to be made
self._voice_data_found = asyncio.Event(loop=self.loop)
self._session_id_found = asyncio.Event(loop=self.loop)
# internals
def _get_cache_filename(self, email):
@ -280,72 +275,6 @@ class Client:
print('Ignoring exception in {}'.format(event_method), file=sys.stderr)
traceback.print_exc()
@asyncio.coroutine
def received_message(self, msg):
self.dispatch('socket_raw_receive', msg)
if isinstance(msg, bytes):
msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB
msg = msg.decode('utf-8')
msg = json.loads(msg)
log.debug('WebSocket Event: {}'.format(msg))
self.dispatch('socket_response', msg)
op = msg.get('op')
data = msg.get('d')
if 's' in msg:
self.sequence = msg['s']
if op == 7:
# redirect op code
yield from self.ws.close()
yield from self.redirect_websocket(data.get('url'))
return
if op != 0:
log.info('Unhandled op {}'.format(op))
return
event = msg.get('t')
is_ready = event == 'READY'
if is_ready:
self.connection.clear()
self.session_id = data['session_id']
if is_ready or event == 'RESUMED':
interval = data['heartbeat_interval'] / 1000.0
self.keep_alive = compat.create_task(self.keep_alive_handler(interval), loop=self.loop)
if event == 'VOICE_STATE_UPDATE':
user_id = data.get('user_id')
if user_id == self.user.id:
if self.is_voice_connected():
self.voice.channel = self.get_channel(data.get('channel_id'))
self.session_id = data.get('session_id')
log.debug('Session ID found: {}'.format(self.session_id))
self._session_id_found.set()
if event == 'VOICE_SERVER_UPDATE':
self._voice_data_found.data = data
log.debug('Voice connection data found: {}'.format(data))
self._voice_data_found.set()
return
parser = 'parse_' + event.lower()
try:
func = getattr(self.connection, parser)
except AttributeError:
log.info('Unhandled event {}'.format(event))
else:
result = func(data)
# login state management
@asyncio.coroutine
@ -2442,7 +2371,6 @@ class Client:
:class:`VoiceClient`
A voice client that is fully connected to the voice server.
"""
if self.is_voice_connected():
raise ClientException('Already connected to a voice channel')
@ -2454,29 +2382,29 @@ class Client:
log.info('attempting to join voice channel {0.name}'.format(channel))
payload = {
'op': 4,
'd': {
'guild_id': channel.server.id,
'channel_id': channel.id,
'self_mute': False,
'self_deaf': False
}
}
def session_id_found(data):
user_id = data.get('user_id')
return user_id == self.user.id
yield from self._send_ws(utils.to_json(payload))
yield from asyncio.wait_for(self._session_id_found.wait(), timeout=5.0, loop=self.loop)
yield from asyncio.wait_for(self._voice_data_found.wait(), timeout=5.0, loop=self.loop)
# register the futures for waiting
session_id_future = self.ws.wait_for('VOICE_STATE_UPDATE', session_id_found)
voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True)
self._session_id_found.clear()
self._voice_data_found.clear()
# request joining
yield from self.ws.voice_state(channel.server.id, channel.id)
session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop)
data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop)
# todo: multivoice
if self.is_voice_connected():
self.voice.channel = self.get_channel(session_id_data.get('channel_id'))
kwargs = {
'user': self.user,
'channel': channel,
'data': self._voice_data_found.data,
'data': data,
'loop': self.loop,
'session_id': self.session_id,
'session_id': session_id_data.get('session_id'),
'main_ws': self.ws
}

190
discord/gateway.py

@ -36,11 +36,13 @@ import logging
import zlib, time, json
from collections import namedtuple
import threading
import struct
log = logging.getLogger(__name__)
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket',
'KeepAliveHandler' ]
'KeepAliveHandler', 'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket' ]
class ReconnectWebSocket(Exception):
"""Signals to handle the RECONNECT opcode."""
@ -56,13 +58,13 @@ class KeepAliveHandler(threading.Thread):
self.ws = ws
self.interval = interval
self.daemon = True
self.msg = 'Keeping websocket alive with sequence {0[d]}'
self._stop = threading.Event()
def run(self):
while not self._stop.wait(self.interval):
data = self.get_payload()
msg = 'Keeping websocket alive with sequence {0[d]}'.format(data)
log.debug(msg)
log.debug(self.msg.format(data))
coro = self.ws.send_as_json(data)
f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
@ -80,6 +82,17 @@ class KeepAliveHandler(threading.Thread):
def stop(self):
self._stop.set()
class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.msg = 'Keeping voice websocket alive with timestamp {0[d]}'
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000)
}
@asyncio.coroutine
def get_gateway(token, *, loop=None):
@ -212,7 +225,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
connection=client.connection,
loop=client.loop)
def wait_for(self, event, predicate, result):
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
Parameters
@ -224,7 +237,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
properties. The data parameter is the 'd' key in the JSON message.
result
A function that takes the same data parameter and executes to send
the result to the future.
the result to the future. If None, returns the data.
Returns
--------
@ -281,6 +294,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# "reconnect" can only be handled by the Client
# so we terminate our connection and raise an
# internal exception signalling to reconnect.
log.info('Receivede RECONNECT opcode.')
yield from self.close()
raise ReconnectWebSocket()
@ -332,7 +346,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
removed.append(index)
else:
if valid:
future.set_result(entry.result)
ret = data if entry.result is None else entry.result(data)
future.set_result(ret)
removed.append(index)
for index in reversed(removed):
@ -352,6 +367,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from self.received_message(msg)
except websockets.exceptions.ConnectionClosed as e:
if e.code in (4008, 4009) or e.code in range(1001, 1015):
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
raise ReconnectWebSocket() from e
else:
raise ConnectionClosed(e) from e
@ -394,9 +410,171 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
status = Status.idle if idle_since else Status.online
me.status = status
@asyncio.coroutine
def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
payload = {
'op': self.VOICE_STATE,
'd': {
'guild_id': guild_id,
'channel_id': channel_id,
'self_mute': self_mute,
'self_deaf': self_deaf
}
}
yield from self.send_as_json(payload)
@asyncio.coroutine
def close(self, code=1000, reason=''):
if self._keep_alive:
self._keep_alive.stop()
yield from super().close(code, reason)
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
"""Implements the websocket protocol for handling voice connections.
Attributes
-----------
IDENTIFY
Send only. Starts a new voice session.
SELECT_PROTOCOL
Send only. Tells discord what encryption mode and how to connect for voice.
READY
Receive only. Tells the websocket that the initial connection has completed.
HEARTBEAT
Send only. Keeps your websocket connection alive.
SESSION_DESCRIPTION
Receive only. Gives you the secret key required for voice.
SPEAKING
Send only. Notifies the client if you are currently speaking.
"""
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = None
self._keep_alive = None
@asyncio.coroutine
def send_as_json(self, data):
yield from self.send(utils.to_json(data))
@classmethod
@asyncio.coroutine
def from_client(cls, client):
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
ws.gateway = gateway
ws._connection = client
identify = {
'op': cls.IDENTIFY,
'd': {
'server_id': client.guild_id,
'user_id': client.user.id,
'session_id': client.session_id,
'token': client.token
}
}
yield from ws.send_as_json(identify)
return ws
@asyncio.coroutine
def select_protocol(self, ip, port):
payload = {
'op': self.SELECT_PROTOCOL,
'd': {
'protocol': 'udp',
'data': {
'address': ip,
'port': port,
'mode': 'xsalsa20_poly1305'
}
}
}
yield from self.send_as_json(payload)
log.debug('Selected protocol as {}'.format(payload))
@asyncio.coroutine
def speak(self, is_speaking=True):
payload = {
'op': self.SPEAKING,
'd': {
'speaking': is_speaking,
'delay': 0
}
}
yield from self.send_as_json(payload)
log.debug('Voice speaking now set to {}'.format(is_speaking))
@asyncio.coroutine
def received_message(self, msg):
log.debug('Voice websocket frame received: {}'.format(msg))
op = msg.get('op')
data = msg.get('d')
if op == self.READY:
interval = (data['heartbeat_interval'] / 100.0) - 5
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval)
self._keep_alive.start()
yield from self.initial_connection(data)
elif op == self.SESSION_DESCRIPTION:
yield from self.load_secret_key(data)
@asyncio.coroutine
def initial_connection(self, data):
state = self._connection
state.ssrc = data.get('ssrc')
state.voice_port = data.get('port')
packet = bytearray(70)
struct.pack_into('>I', packet, 0, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = yield from self.loop.sock_recv(state.socket, 70)
log.debug('received packet in initial_connection: {}'.format(recv))
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
# the port is a little endian unsigned short in the last two bytes
# yes, this is different endianness from everything else
state.port = struct.unpack_from('<H', recv, len(recv) - 2)[0]
log.debug('detected ip: {0.ip} port: {0.port}'.format(state))
yield from self.select_protocol(state.ip, state.port)
log.info('selected the voice protocol for use')
@asyncio.coroutine
def load_secret_key(self, data):
log.info('received secret key for voice connection')
self._connection.secret_key = data.get('secret_key')
yield from self.speak()
@asyncio.coroutine
def poll_event(self):
try:
msg = yield from self.recv()
yield from self.received_message(json.loads(msg))
except websockets.exceptions.ConnectionClosed as e:
raise ConnectionClosed(e) from e
@asyncio.coroutine
def close(self, code=1000, reason=''):
if self._keep_alive:
self._keep_alive.stop()
yield from super().close(code, reason)

124
discord/voice_client.py

@ -55,6 +55,7 @@ import nacl.secret
log = logging.getLogger(__name__)
from . import utils
from .gateway import *
from .errors import ClientException, InvalidArgument
from .opus import Encoder as OpusEncoder
@ -173,7 +174,6 @@ class VoiceClient:
self.sequence = 0
self.timestamp = 0
self.encoder = OpusEncoder(48000, 2)
self.secret_key = []
log.info('created opus encoder with {0.__dict__}'.format(self.encoder))
def checked_add(self, attr, value, limit):
@ -183,87 +183,6 @@ class VoiceClient:
else:
setattr(self, attr, val + value)
@asyncio.coroutine
def keep_alive_handler(self, delay):
try:
while True:
payload = {
'op': 3,
'd': int(time.time())
}
msg = 'Keeping voice websocket alive with timestamp {}'
log.debug(msg.format(payload['d']))
yield from self.ws.send(utils.to_json(payload))
yield from asyncio.sleep(delay)
except asyncio.CancelledError:
pass
@asyncio.coroutine
def received_message(self, msg):
log.debug('Voice websocket frame received: {}'.format(msg))
op = msg.get('op')
data = msg.get('d')
if op == 2:
delay = (data['heartbeat_interval'] / 100.0) - 5
self.keep_alive = utils.create_task(self.keep_alive_handler(delay), loop=self.loop)
yield from self.initial_connection(data)
elif op == 4:
yield from self.connection_ready(data)
@asyncio.coroutine
def initial_connection(self, data):
self.ssrc = data.get('ssrc')
self.voice_port = data.get('port')
packet = bytearray(70)
struct.pack_into('>I', packet, 0, self.ssrc)
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
recv = yield from self.loop.sock_recv(self.socket, 70)
log.debug('received packet in initial_connection: {}'.format(recv))
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
self.ip = recv[ip_start:ip_end].decode('ascii')
# the port is a little endian unsigned short in the last two bytes
# yes, this is different endianness from everything else
self.port = struct.unpack_from('<H', recv, len(recv) - 2)[0]
log.debug('detected ip: {} port: {}'.format(self.ip, self.port))
payload = {
'op': 1,
'd': {
'protocol': 'udp',
'data': {
'address': self.ip,
'port': self.port,
'mode': 'xsalsa20_poly1305'
}
}
}
yield from self.ws.send(utils.to_json(payload))
log.debug('sent {} to initialize voice connection'.format(payload))
log.info('initial voice connection is done')
@asyncio.coroutine
def connection_ready(self, data):
log.info('voice connection is now ready')
self.secret_key = data.get('secret_key')
speaking = {
'op': 5,
'd': {
'speaking': True,
'delay': 0
}
}
yield from self.ws.send(utils.to_json(speaking))
self._connected.set()
# connection related
@asyncio.coroutine
@ -275,28 +194,15 @@ class VoiceClient:
self.socket.setblocking(False)
log.info('Voice endpoint found {0.endpoint} (IP: {0.endpoint_ip})'.format(self))
self.ws = yield from websockets.connect('wss://' + self.endpoint, loop=self.loop)
self.ws.max_size = None
payload = {
'op': 0,
'd': {
'server_id': self.guild_id,
'user_id': self.user.id,
'session_id': self.session_id,
'token': self.token
}
}
yield from self.ws.send(utils.to_json(payload))
self.ws = yield from DiscordVoiceWebSocket.from_client(self)
while not self._connected.is_set():
msg = yield from self.ws.recv()
if msg is None:
yield from self.disconnect()
raise ClientException('Unexpected websocket close on voice websocket')
yield from self.received_message(json.loads(msg))
yield from self.ws.poll_event()
if hasattr(self, 'secret_key'):
# we have a secret key, so we don't need to poll
# websocket events anymore
self._connected.set()
break
@asyncio.coroutine
def disconnect(self):
@ -310,22 +216,10 @@ class VoiceClient:
if not self._connected.is_set():
return
self.keep_alive.cancel()
self.socket.close()
self._connected.clear()
yield from self.ws.close()
payload = {
'op': 4,
'd': {
'guild_id': self.guild_id,
'channel_id': None,
'self_mute': True,
'self_deaf': False
}
}
yield from self.main_ws.send(utils.to_json(payload))
yield from self.main_ws.voice_state(self.guild_id, None, self_mute=True)
def is_connected(self):
"""bool : Indicates if the voice client is connected to voice."""

Loading…
Cancel
Save