diff --git a/discord/client.py b/discord/client.py index 09165a77c..5d1729fba 100644 --- a/discord/client.py +++ b/discord/client.py @@ -35,6 +35,7 @@ from .errors import * from .state import ConnectionState from . import utils from .enums import ChannelType, ServerRegion +from .voice_client import VoiceClient import asyncio import aiohttp @@ -47,9 +48,6 @@ log = logging.getLogger(__name__) request_logging_format = '{method} {response.url} has returned {response.status}' request_success_log = '{response.url} with {json} received {data}' -def to_json(obj): - return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) - class Client: """Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -108,6 +106,15 @@ class Client: self._closed = False self._is_logged_in = False + # this is shared state between Client and VoiceClient + # could this lead to issues? Not sure. I want to say no. + self._is_voice_connected = asyncio.Event(loop=self.loop) + + # These two events correspond to the two events necessary + # for a connection to be made + self._voice_data_found = asyncio.Event(loop=self.loop) + self._session_id_found = asyncio.Event(loop=self.loop) + # internals def _resolve_mentions(self, content, mentions): @@ -195,7 +202,7 @@ class Client: msg = 'Keeping websocket alive with timestamp {}' log.debug(msg.format(payload['d'])) - yield from self.ws.send(to_json(payload)) + yield from self.ws.send(utils.to_json(payload)) yield from asyncio.sleep(interval) @asyncio.coroutine @@ -228,6 +235,16 @@ class Client: interval = data['heartbeat_interval'] / 1000.0 self.keep_alive = utils.create_task(self.keep_alive_handler(interval), loop=self.loop) + if event == 'VOICE_STATE_UPDATE': + user_id = data.get('user_id') + if user_id == self.user.id: + self.session_id = data.get('session_id') + self._session_id_found.set() + + if event == 'VOICE_SERVER_UPDATE': + self._voice_data_found.data = data + self._voice_data_found.set() + if event in ('READY', 'MESSAGE_CREATE', 'MESSAGE_DELETE', 'MESSAGE_UPDATE', 'PRESENCE_UPDATE', 'USER_UPDATE', 'CHANNEL_DELETE', 'CHANNEL_UPDATE', 'CHANNEL_CREATE', @@ -264,7 +281,7 @@ class Client: } } - yield from self.ws.send(to_json(payload)) + yield from self.ws.send(utils.to_json(payload)) log.info('sent the initial payload to create the websocket') # properties @@ -348,7 +365,7 @@ class Client: 'password': password } - data = to_json(payload) + data = utils.to_json(payload) resp = yield from self.session.post(endpoints.LOGIN, data=data, headers=self.headers) log.debug(request_logging_format.format(method='POST', response=resp)) if resp.status == 400: @@ -513,7 +530,7 @@ class Client: } url = '{}/@me/channels'.format(endpoints.USERS) - r = yield from self.session.post(url, data=to_json(payload), headers=self.headers) + r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) log.debug(request_logging_format.format(method='POST', response=r)) yield from utils._verify_successful_response(r) data = yield from r.json() @@ -584,7 +601,7 @@ class Client: if tts: payload['tts'] = True - resp = yield from self.session.post(url, data=to_json(payload), headers=self.headers) + resp = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) log.debug(request_logging_format.format(method='POST', response=resp)) yield from utils._verify_successful_response(resp) data = yield from resp.json() @@ -755,7 +772,7 @@ class Client: 'mentions': self._resolve_mentions(content, mentions) } - response = yield from self.session.patch(url, headers=self.headers, data=to_json(payload)) + response = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='PATCH', response=response)) yield from utils._verify_successful_response(response) data = yield from response.json() @@ -952,7 +969,7 @@ class Client: 'deaf': deafen } - response = yield from self.session.patch(url, headers=self.headers, data=to_json(payload)) + response = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='PATCH', response=response)) yield from utils._verify_successful_response(response) @@ -1016,7 +1033,7 @@ class Client: } url = '{0}/@me'.format(endpoints.USERS) - r = yield from self.session.patch(url, headers=self.headers, data=to_json(payload)) + r = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='PATCH', response=r)) yield from utils._verify_successful_response(r) @@ -1069,7 +1086,7 @@ class Client: } } - sent = to_json(payload) + sent = utils.to_json(payload) log.debug('Sending "{}" to change status'.format(sent)) yield from self.ws.send(sent) @@ -1111,7 +1128,7 @@ class Client: 'position': options.get('position', channel.position) } - r = yield from self.session.patch(url, headers=self.headers, data=to_json(payload)) + r = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='PATCH', response=r)) yield from utils._verify_successful_response(r) @@ -1160,7 +1177,7 @@ class Client: } url = '{0}/{1.id}/channels'.format(endpoints.SERVERS, server) - response = yield from self.session.post(url, headers=self.headers, data=to_json(payload)) + response = yield from self.session.post(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='POST', response=response)) yield from utils._verify_successful_response(response) @@ -1329,7 +1346,7 @@ class Client: payload['afk_channel'] = getattr(afk_channel, 'id', None) url = '{0}/{1.id}'.format(endpoints.SERVERS, server) - r = yield from self.session.patch(url, headers=self.headers, data=to_json(payload)) + r = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='PATCH', response=r)) yield from utils._verify_successful_response(r) @@ -1377,7 +1394,7 @@ class Client: } url = '{0}/{1.id}/invites'.format(endpoints.CHANNELS, destination) - response = yield from self.session.post(url, headers=self.headers, data=to_json(payload)) + response = yield from self.session.post(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='POST', response=response)) yield from utils._verify_successful_response(response) @@ -1497,7 +1514,6 @@ class Client: log.debug(request_logging_format.format(method='DELETE', response=response)) yield from utils._verify_successful_response(response) - # Role management @asyncio.coroutine @@ -1555,7 +1571,7 @@ class Client: 'hoist': fields.get('hoist', role.hoist) } - r = yield from self.session.patch(url, data=to_json(payload), headers=self.headers) + r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) log.debug(request_logging_format.format(method='PATCH', response=r)) yield from utils._verify_successful_response(r) @@ -1683,7 +1699,7 @@ class Client: 'roles': [role.id for role in roles] } - r = yield from self.session.patch(url, headers=self.headers, data=to_json(payload)) + r = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='PATCH', response=r)) yield from utils._verify_successful_response(r) @@ -1788,7 +1804,7 @@ class Client: else: raise InvalidArgument('target parameter must be either discord.Member or discord.Role') - r = yield from self.session.put(url, data=to_json(payload), headers=self.headers) + r = yield from self.session.put(url, data=utils.to_json(payload), headers=self.headers) log.debug(request_logging_format.format(method='PUT', response=r)) yield from utils._verify_successful_response(r) @@ -1824,3 +1840,72 @@ class Client: response = yield from self.session.delete(url, headers=self.headers) log.debug(request_logging_format.format(method='DELETE', response=response)) yield from utils._verify_successful_response(response) + + + # Voice management + + @asyncio.coroutine + def join_voice_channel(self, channel): + """|coro| + + Joins a voice channel and creates a :class:`VoiceClient` to + establish your connection to the voice server. + + Parameters + ---------- + channel : :class:`Channel` + The voice channel to join to. + + Raises + ------- + InvalidArgument + The channel was not a voice channel. + asyncio.TimeoutError + Could not connect to the voice channel in time. + ClientException + You are already connected to a voice channel. + + Returns + ------- + :class:`VoiceClient` + A voice client that is fully connected to the voice server. + """ + + if self._is_voice_connected.is_set(): + raise ClientException('Already connected to a voice channel') + + if getattr(channel, 'type', ChannelType.text) != ChannelType.voice: + raise InvalidArgument('Channel passed must be a voice channel') + + self.voice_channel = channel + log.info('attempting to join voice channel {0.name}'.format(channel)) + + payload = { + 'op': 4, + 'd': { + 'guild_id': self.voice_channel.server.id, + 'channel_id': self.voice_channel.id, + 'self_mute': False, + 'self_deaf': False + } + } + + yield from self.ws.send(utils.to_json(payload)) + yield from asyncio.wait_for(self._session_id_found.wait(), timeout=5.0, loop=self.loop) + yield from asyncio.wait_for(self._voice_data_found.wait(), timeout=5.0, loop=self.loop) + + self._session_id_found.clear() + self._voice_data_found.clear() + + kwargs = { + 'user': self.user, + 'connected': self._is_voice_connected, + 'channel': self.voice_channel, + 'data': self._voice_data_found.data, + 'loop': self.loop, + 'session_id': self.session_id + } + + result = VoiceClient(**kwargs) + yield from result.connect() + return result diff --git a/discord/utils.py b/discord/utils.py index ff0355085..97584b312 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -29,6 +29,7 @@ from .errors import HTTPException, Forbidden, NotFound, InvalidArgument import datetime from base64 import b64encode import asyncio +import json def parse_time(timestamp): if timestamp: @@ -88,6 +89,9 @@ def _bytes_to_base64_data(data): b64 = b64encode(data).decode('ascii') return fmt.format(mime=mime, data=b64) +def to_json(obj): + return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) + try: create_task = asyncio.ensure_future except AttributeError: diff --git a/discord/voice_client.py b/discord/voice_client.py new file mode 100644 index 000000000..9f0ea3efc --- /dev/null +++ b/discord/voice_client.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +"""Some documentation to refer to: + +- Our main web socket (mWS) sends opcode 4 with a server ID and channel ID. +- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. +- We pull the session_id from VOICE_STATE_UPDATE. +- We pull the token, endpoint and guild_id from VOICE_SERVER_UPDATE. +- Then we initiate the voice web socket (vWS) pointing to the endpoint. +- We send opcode 0 with the user_id, guild_id, session_id and token using the vWS. +- The vWS sends back opcode 2 with an ssrc, port, modes(array) and hearbeat_interval. +- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. +- Then we send our IP and port via vWS with opcode 1. +- When that's all done, we receive opcode 4 from the vWS. +- Finally we can transmit data to endpoint:port. +""" + +import asyncio +import websockets +import socket +import json, time +import logging +import struct + +log = logging.getLogger(__name__) + +from . import utils +from .errors import ClientException + +class VoiceClient: + """Represents a Discord voice connection. + + This client is created solely through :meth:`Client.join_voice_channel` + and its only purpose is to transmit voice. + + Attributes + ----------- + session_id : str + The voice connection session ID. + token : str + The voice connection token. + user : :class:`User` + The user connected to voice. + endpoint : str + The endpoint we are connecting to. + channel : :class:`Channel` + The voice channel connected to. + """ + def __init__(self, user, connected, session_id, channel, data, loop): + self.user = user + self._connected = connected + self.channel = channel + self.session_id = session_id + self.loop = loop + self.token = data.get('token') + self.guild_id = data.get('guild_id') + self.endpoint = data.get('endpoint') + + @asyncio.coroutine + def keep_alive_handler(self, delay): + while True: + payload = { + 'op': 3, + 'd': int(time.time()) + } + + msg = 'Keeping voice websocket alive with timestamp {}' + log.debug(msg.format(payload['d'])) + yield from self.ws.send(utils.to_json(payload)) + yield from asyncio.sleep(delay) + + @asyncio.coroutine + def received_message(self, msg): + log.debug('Voice websocket frame received: {}'.format(msg)) + op = msg.get('op') + data = msg.get('d') + + if op == 2: + delay = (data['heartbeat_interval'] / 100.0) - 5 + self.keep_alive = utils.create_task(self.keep_alive_handler(delay), loop=self.loop) + yield from self.initial_connection(data) + elif op == 4: + yield from self.connection_ready(data) + + @asyncio.coroutine + def initial_connection(self, data): + self.ssrc = data.get('ssrc') + self.voice_port = data.get('port') + packet = bytearray(70) + struct.pack_into('>I', packet, 0, self.ssrc) + self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) + recv = yield from self.loop.sock_recv(self.socket, 70) + self.ip = [] + + for x in range(4, len(recv)): + val = recv[x] + if val == 0: + break + self.ip.append(str(val)) + + self.ip = '.'.join(self.ip) + self.port = recv[len(recv) - 2] << 0 | recv[len(recv) - 1] << 1 + + payload = { + 'op': 1, + 'd': { + 'protocol': 'udp', + 'data': { + 'address': self.ip, + 'port': self.port, + 'mode': 'plain' + } + } + } + + yield from self.ws.send(utils.to_json(payload)) + log.debug('sent {} to initialize voice connection'.format(payload)) + log.info('initial voice connection is done') + + @asyncio.coroutine + def connection_ready(self, data): + log.info('voice connection is now ready') + speaking = { + 'op': 5, + 'd': { + 'speaking': True, + 'delay': 0 + } + } + + yield from self.ws.send(utils.to_json(speaking)) + self._connected.set() + + @asyncio.coroutine + def connect(self): + log.info('voice connection is connecting...') + self.endpoint = self.endpoint.replace(':80', '') + self.endpoint_ip = socket.gethostbyname(self.endpoint) + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.setblocking(False) + + log.info('Voice endpoint found {0.endpoint} (IP: {0.endpoint_ip})'.format(self)) + self.ws = yield from websockets.connect('wss://' + self.endpoint, loop=self.loop) + self.ws.max_size = None + + payload = { + 'op': 0, + 'd': { + 'server_id': self.guild_id, + 'user_id': self.user.id, + 'session_id': self.session_id, + 'token': self.token + } + } + + yield from self.ws.send(utils.to_json(payload)) + + while not self._connected.is_set(): + msg = yield from self.ws.recv() + if msg is None: + yield from self.disconnect() + raise ClientException('Unexpected websocket close on voice websocket') + + yield from self.received_message(json.loads(msg)) + + @asyncio.coroutine + def disconnect(self): + """|coro| + + Disconnects all connections to the voice client. + + In order to reconnect, you must create another voice client + using :meth:`Client.join_voice_channel`. + """ + if not self._connected.is_set(): + return + + self.keep_alive.cancel() + self.socket.shutdown(socket.SHUT_RDWR) + self.socket.close() + self._connected.clear() + yield from self.ws.close()