|
|
@ -36,11 +36,13 @@ import logging |
|
|
|
import zlib, time, json |
|
|
|
from collections import namedtuple |
|
|
|
import threading |
|
|
|
import struct |
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
|
|
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', |
|
|
|
'KeepAliveHandler' ] |
|
|
|
'KeepAliveHandler', 'VoiceKeepAliveHandler', |
|
|
|
'DiscordVoiceWebSocket' ] |
|
|
|
|
|
|
|
class ReconnectWebSocket(Exception): |
|
|
|
"""Signals to handle the RECONNECT opcode.""" |
|
|
@ -56,13 +58,13 @@ class KeepAliveHandler(threading.Thread): |
|
|
|
self.ws = ws |
|
|
|
self.interval = interval |
|
|
|
self.daemon = True |
|
|
|
self.msg = 'Keeping websocket alive with sequence {0[d]}' |
|
|
|
self._stop = threading.Event() |
|
|
|
|
|
|
|
def run(self): |
|
|
|
while not self._stop.wait(self.interval): |
|
|
|
data = self.get_payload() |
|
|
|
msg = 'Keeping websocket alive with sequence {0[d]}'.format(data) |
|
|
|
log.debug(msg) |
|
|
|
log.debug(self.msg.format(data)) |
|
|
|
coro = self.ws.send_as_json(data) |
|
|
|
f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop) |
|
|
|
try: |
|
|
@ -80,6 +82,17 @@ class KeepAliveHandler(threading.Thread): |
|
|
|
def stop(self): |
|
|
|
self._stop.set() |
|
|
|
|
|
|
|
class VoiceKeepAliveHandler(KeepAliveHandler): |
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.msg = 'Keeping voice websocket alive with timestamp {0[d]}' |
|
|
|
|
|
|
|
def get_payload(self): |
|
|
|
return { |
|
|
|
'op': self.ws.HEARTBEAT, |
|
|
|
'd': int(time.time() * 1000) |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def get_gateway(token, *, loop=None): |
|
|
@ -212,7 +225,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
connection=client.connection, |
|
|
|
loop=client.loop) |
|
|
|
|
|
|
|
def wait_for(self, event, predicate, result): |
|
|
|
def wait_for(self, event, predicate, result=None): |
|
|
|
"""Waits for a DISPATCH'd event that meets the predicate. |
|
|
|
|
|
|
|
Parameters |
|
|
@ -224,7 +237,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
properties. The data parameter is the 'd' key in the JSON message. |
|
|
|
result |
|
|
|
A function that takes the same data parameter and executes to send |
|
|
|
the result to the future. |
|
|
|
the result to the future. If None, returns the data. |
|
|
|
|
|
|
|
Returns |
|
|
|
-------- |
|
|
@ -281,6 +294,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
# "reconnect" can only be handled by the Client |
|
|
|
# so we terminate our connection and raise an |
|
|
|
# internal exception signalling to reconnect. |
|
|
|
log.info('Receivede RECONNECT opcode.') |
|
|
|
yield from self.close() |
|
|
|
raise ReconnectWebSocket() |
|
|
|
|
|
|
@ -332,7 +346,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
removed.append(index) |
|
|
|
else: |
|
|
|
if valid: |
|
|
|
future.set_result(entry.result) |
|
|
|
ret = data if entry.result is None else entry.result(data) |
|
|
|
future.set_result(ret) |
|
|
|
removed.append(index) |
|
|
|
|
|
|
|
for index in reversed(removed): |
|
|
@ -352,6 +367,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
yield from self.received_message(msg) |
|
|
|
except websockets.exceptions.ConnectionClosed as e: |
|
|
|
if e.code in (4008, 4009) or e.code in range(1001, 1015): |
|
|
|
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e)) |
|
|
|
raise ReconnectWebSocket() from e |
|
|
|
else: |
|
|
|
raise ConnectionClosed(e) from e |
|
|
@ -394,9 +410,171 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
status = Status.idle if idle_since else Status.online |
|
|
|
me.status = status |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): |
|
|
|
payload = { |
|
|
|
'op': self.VOICE_STATE, |
|
|
|
'd': { |
|
|
|
'guild_id': guild_id, |
|
|
|
'channel_id': channel_id, |
|
|
|
'self_mute': self_mute, |
|
|
|
'self_deaf': self_deaf |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
yield from self.send_as_json(payload) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def close(self, code=1000, reason=''): |
|
|
|
if self._keep_alive: |
|
|
|
self._keep_alive.stop() |
|
|
|
|
|
|
|
yield from super().close(code, reason) |
|
|
|
|
|
|
|
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): |
|
|
|
"""Implements the websocket protocol for handling voice connections. |
|
|
|
|
|
|
|
Attributes |
|
|
|
----------- |
|
|
|
IDENTIFY |
|
|
|
Send only. Starts a new voice session. |
|
|
|
SELECT_PROTOCOL |
|
|
|
Send only. Tells discord what encryption mode and how to connect for voice. |
|
|
|
READY |
|
|
|
Receive only. Tells the websocket that the initial connection has completed. |
|
|
|
HEARTBEAT |
|
|
|
Send only. Keeps your websocket connection alive. |
|
|
|
SESSION_DESCRIPTION |
|
|
|
Receive only. Gives you the secret key required for voice. |
|
|
|
SPEAKING |
|
|
|
Send only. Notifies the client if you are currently speaking. |
|
|
|
""" |
|
|
|
|
|
|
|
IDENTIFY = 0 |
|
|
|
SELECT_PROTOCOL = 1 |
|
|
|
READY = 2 |
|
|
|
HEARTBEAT = 3 |
|
|
|
SESSION_DESCRIPTION = 4 |
|
|
|
SPEAKING = 5 |
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.max_size = None |
|
|
|
self._keep_alive = None |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def send_as_json(self, data): |
|
|
|
yield from self.send(utils.to_json(data)) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
@asyncio.coroutine |
|
|
|
def from_client(cls, client): |
|
|
|
"""Creates a voice websocket for the :class:`VoiceClient`.""" |
|
|
|
gateway = 'wss://' + client.endpoint |
|
|
|
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) |
|
|
|
ws.gateway = gateway |
|
|
|
ws._connection = client |
|
|
|
|
|
|
|
identify = { |
|
|
|
'op': cls.IDENTIFY, |
|
|
|
'd': { |
|
|
|
'server_id': client.guild_id, |
|
|
|
'user_id': client.user.id, |
|
|
|
'session_id': client.session_id, |
|
|
|
'token': client.token |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
yield from ws.send_as_json(identify) |
|
|
|
return ws |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def select_protocol(self, ip, port): |
|
|
|
payload = { |
|
|
|
'op': self.SELECT_PROTOCOL, |
|
|
|
'd': { |
|
|
|
'protocol': 'udp', |
|
|
|
'data': { |
|
|
|
'address': ip, |
|
|
|
'port': port, |
|
|
|
'mode': 'xsalsa20_poly1305' |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
yield from self.send_as_json(payload) |
|
|
|
log.debug('Selected protocol as {}'.format(payload)) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def speak(self, is_speaking=True): |
|
|
|
payload = { |
|
|
|
'op': self.SPEAKING, |
|
|
|
'd': { |
|
|
|
'speaking': is_speaking, |
|
|
|
'delay': 0 |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
yield from self.send_as_json(payload) |
|
|
|
log.debug('Voice speaking now set to {}'.format(is_speaking)) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def received_message(self, msg): |
|
|
|
log.debug('Voice websocket frame received: {}'.format(msg)) |
|
|
|
op = msg.get('op') |
|
|
|
data = msg.get('d') |
|
|
|
|
|
|
|
if op == self.READY: |
|
|
|
interval = (data['heartbeat_interval'] / 100.0) - 5 |
|
|
|
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval) |
|
|
|
self._keep_alive.start() |
|
|
|
yield from self.initial_connection(data) |
|
|
|
elif op == self.SESSION_DESCRIPTION: |
|
|
|
yield from self.load_secret_key(data) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def initial_connection(self, data): |
|
|
|
state = self._connection |
|
|
|
state.ssrc = data.get('ssrc') |
|
|
|
state.voice_port = data.get('port') |
|
|
|
packet = bytearray(70) |
|
|
|
struct.pack_into('>I', packet, 0, state.ssrc) |
|
|
|
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) |
|
|
|
recv = yield from self.loop.sock_recv(state.socket, 70) |
|
|
|
log.debug('received packet in initial_connection: {}'.format(recv)) |
|
|
|
|
|
|
|
# the ip is ascii starting at the 4th byte and ending at the first null |
|
|
|
ip_start = 4 |
|
|
|
ip_end = recv.index(0, ip_start) |
|
|
|
state.ip = recv[ip_start:ip_end].decode('ascii') |
|
|
|
|
|
|
|
# the port is a little endian unsigned short in the last two bytes |
|
|
|
# yes, this is different endianness from everything else |
|
|
|
state.port = struct.unpack_from('<H', recv, len(recv) - 2)[0] |
|
|
|
|
|
|
|
log.debug('detected ip: {0.ip} port: {0.port}'.format(state)) |
|
|
|
yield from self.select_protocol(state.ip, state.port) |
|
|
|
log.info('selected the voice protocol for use') |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def load_secret_key(self, data): |
|
|
|
log.info('received secret key for voice connection') |
|
|
|
self._connection.secret_key = data.get('secret_key') |
|
|
|
yield from self.speak() |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def poll_event(self): |
|
|
|
try: |
|
|
|
msg = yield from self.recv() |
|
|
|
yield from self.received_message(json.loads(msg)) |
|
|
|
except websockets.exceptions.ConnectionClosed as e: |
|
|
|
raise ConnectionClosed(e) from e |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def close(self, code=1000, reason=''): |
|
|
|
if self._keep_alive: |
|
|
|
self._keep_alive.stop() |
|
|
|
|
|
|
|
yield from super().close(code, reason) |
|
|
|
|
|
|
|
|
|
|
|