diff --git a/discord/__init__.py b/discord/__init__.py index 17e640b0c..555d73ffe 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -30,8 +30,7 @@ from .role import Role from .colour import Color, Colour from .invite import Invite from .object import Object -from . import utils -from . import opus +from . import utils, opus, compat from .voice_client import VoiceClient from .enums import ChannelType, ServerRegion, Status from collections import namedtuple diff --git a/discord/client.py b/discord/client.py index 8e5f1c041..d8dd9b12d 100644 --- a/discord/client.py +++ b/discord/client.py @@ -28,7 +28,6 @@ from . import __version__ as library_version from . import endpoints from .user import User from .member import Member -from .game import Game from .channel import Channel, PrivateChannel from .server import Server from .message import Message @@ -38,10 +37,11 @@ from .role import Role from .errors import * from .state import ConnectionState from .permissions import Permissions -from . import utils +from . import utils, compat from .enums import ChannelType, ServerRegion, Status from .voice_client import VoiceClient from .iterators import LogsFromIterator +from .gateway import * import asyncio import aiohttp @@ -51,7 +51,6 @@ import logging, traceback import sys, time, re, json import tempfile, os, hashlib import itertools -import zlib from random import randint as random_integer PY35 = sys.version_info >= (3, 5) @@ -115,11 +114,7 @@ class Client: def __init__(self, *, loop=None, **options): self.ws = None self.token = None - self.gateway = None self.voice = None - self.session_id = None - self.keep_alive = None - self.sequence = 0 self.loop = asyncio.get_event_loop() if loop is None else loop self._listeners = [] self.cache_auth = options.get('cache_auth', True) @@ -156,11 +151,6 @@ class Client: filename = hashlib.md5(email.encode('utf-8')).hexdigest() return os.path.join(tempfile.gettempdir(), 'discord_py', filename) - @asyncio.coroutine - def _send_ws(self, data): - self.dispatch('socket_raw_send', data) - yield from self.ws.send(data) - @asyncio.coroutine def _login_via_cache(self, email, password): try: @@ -254,14 +244,6 @@ class Client: else: object.__setattr__(self, name, value) - @asyncio.coroutine - def _get_gateway(self): - resp = yield from self.session.get(endpoints.GATEWAY, headers=self.headers) - if resp.status != 200: - raise GatewayNotFound() - data = yield from resp.json() - return data.get('url') - @asyncio.coroutine def _run_event(self, event, *args, **kwargs): try: @@ -283,23 +265,7 @@ class Client: getattr(self, handler)(*args, **kwargs) if hasattr(self, method): - utils.create_task(self._run_event(method, *args, **kwargs), loop=self.loop) - - @asyncio.coroutine - def keep_alive_handler(self, interval): - try: - while not self.is_closed: - payload = { - 'op': 1, - 'd': int(time.time()) - } - - msg = 'Keeping websocket alive with timestamp {}' - log.debug(msg.format(payload['d'])) - yield from self._send_ws(utils.to_json(payload)) - yield from asyncio.sleep(interval) - except asyncio.CancelledError: - pass + compat.create_task(self._run_event(method, *args, **kwargs), loop=self.loop) @asyncio.coroutine def on_error(self, event_method, *args, **kwargs): @@ -352,7 +318,7 @@ class Client: if is_ready or event == 'RESUMED': interval = data['heartbeat_interval'] / 1000.0 - self.keep_alive = utils.create_task(self.keep_alive_handler(interval), loop=self.loop) + self.keep_alive = compat.create_task(self.keep_alive_handler(interval), loop=self.loop) if event == 'VOICE_STATE_UPDATE': user_id = data.get('user_id') @@ -380,64 +346,6 @@ class Client: else: result = func(data) - @asyncio.coroutine - def _make_websocket(self, initial=True): - if not self.is_logged_in: - raise ClientException('You must be logged in to connect') - - self.ws = yield from websockets.connect(self.gateway, loop=self.loop) - self.ws.max_size = None - log.info('Created websocket connected to {0.gateway}'.format(self)) - - if initial: - payload = { - 'op': 2, - 'd': { - 'token': self.token, - 'properties': { - '$os': sys.platform, - '$browser': 'discord.py', - '$device': 'discord.py', - '$referrer': '', - '$referring_domain': '' - }, - 'compress': True, - 'large_threshold': 250, - 'v': 3 - } - } - - yield from self._send_ws(utils.to_json(payload)) - log.info('sent the initial payload to create the websocket') - - @asyncio.coroutine - def redirect_websocket(self, url): - # if we get redirected then we need to recreate the websocket - # when this recreation happens we have to try to do a reconnection - log.info('redirecting websocket from {} to {}'.format(self.gateway, url)) - self.keep_alive_handler.cancel() - - self.gateway = url - yield from self._make_websocket(initial=False) - yield from self._reconnect_ws() - - if self.is_voice_connected(): - # update the websocket reference pointed to by voice - self.voice.main_ws = self.ws - - @asyncio.coroutine - def _reconnect_ws(self): - payload = { - 'op': 6, - 'd': { - 'session_id': self.session_id, - 'seq': self.sequence - } - } - - log.info('sending reconnection frame to websocket {}'.format(payload)) - yield from self._send_ws(utils.to_json(payload)) - # login state management @asyncio.coroutine @@ -553,29 +461,24 @@ class Client: Raises ------- - ClientException - If this is called before :meth:`login` was invoked successfully - or when an unexpected closure of the websocket occurs. GatewayNotFound If the gateway to connect to discord is not found. Usually if this is thrown then there is a discord API outage. + ConnectionClosed + The websocket connection has been terminated. """ - self.gateway = yield from self._get_gateway() - yield from self._make_websocket() + self.ws = yield from DiscordWebSocket.from_client(self) while not self.is_closed: - msg = yield from self.ws.recv() - if msg is None: - if self.ws.close_code == 1012: - yield from self.redirect_websocket(self.gateway) - continue - elif not self._is_ready.is_set(): - raise ClientException('Unexpected websocket closure received') - else: - yield from self.close() - break - - yield from self.received_message(msg) + try: + yield from self.ws.poll_event() + except ReconnectWebSocket: + log.info('Reconnecting the websocket.') + self.ws = yield from DiscordWebSocket.from_client(self) + except ConnectionClosed as e: + yield from self.close() + if e.code != 1000: + raise @asyncio.coroutine def close(self): @@ -593,9 +496,6 @@ class Client: if self.ws is not None and self.ws.open: yield from self.ws.close() - if self.keep_alive is not None: - self.keep_alive.cancel() - yield from self.session.close() self._closed.set() self._is_ready.clear() @@ -1317,7 +1217,7 @@ class Client: } } - yield from self._send_ws(utils.to_json(payload)) + yield from self.ws.send_as_json(payload) @asyncio.coroutine def kick(self, member): @@ -1568,32 +1468,7 @@ class Client: InvalidArgument If the ``game`` parameter is not :class:`Game` or None. """ - - if game is not None and not isinstance(game, Game): - raise InvalidArgument('game must be of Game or None') - - idle_since = None if idle == False else int(time.time() * 1000) - sent_game = game and {'name': game.name} - - payload = { - 'op': 3, - 'd': { - 'game': sent_game, - 'idle_since': idle_since - } - } - - sent = utils.to_json(payload) - log.debug('Sending "{}" to change status'.format(sent)) - yield from self._send_ws(sent) - for server in self.servers: - me = server.me - if me is None: - continue - - me.game = game - status = Status.idle if idle_since else Status.online - me.status = status + yield from self.ws.change_presence(game=game, idle=idle) # Channel management diff --git a/discord/endpoints.py b/discord/endpoints.py index 0ef0efa96..ac99f74ff 100644 --- a/discord/endpoints.py +++ b/discord/endpoints.py @@ -26,7 +26,7 @@ DEALINGS IN THE SOFTWARE. BASE = 'https://discordapp.com' API_BASE = BASE + '/api' -GATEWAY = API_BASE + '/gateway' +GATEWAY = API_BASE + '/gateway?encoding=json&v=4' USERS = API_BASE + '/users' ME = USERS + '/@me' REGISTER = API_BASE + '/auth/register' diff --git a/discord/errors.py b/discord/errors.py index 62fff0b23..746e17400 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -101,3 +101,21 @@ class LoginFailure(ClientException): failure. """ pass + +class ConnectionClosed(ClientException): + """Exception that's thrown when the gateway connection is + closed for reasons that could not be handled internally. + + Attributes + ----------- + code : int + The close code of the websocket. + reason : str + The reason provided for the closure. + """ + def __init__(self, original): + # This exception is just the same exception except + # reconfigured to subclass ClientException for users + self.code = original.code + self.reason = original.reason + super().__init__(str(original)) diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 9b4e44ef2..3f4fa34dc 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -232,7 +232,7 @@ class Bot(GroupMixin, discord.Client): if ev in self.extra_events: for event in self.extra_events[ev]: coro = self._run_extra(event, event_name, *args, **kwargs) - discord.utils.create_task(coro, loop=self.loop) + discord.compat.create_task(coro, loop=self.loop) # utility "send_*" functions diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index b308bbc1a..cf0658736 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -142,9 +142,9 @@ class Command: injected = inject_context(ctx, coro) if self.instance is not None: - discord.utils.create_task(injected(self.instance, error, ctx), loop=ctx.bot.loop) + discord.compat.create_task(injected(self.instance, error, ctx), loop=ctx.bot.loop) else: - discord.utils.create_task(injected(error, ctx), loop=ctx.bot.loop) + discord.compat.create_task(injected(error, ctx), loop=ctx.bot.loop) def _get_from_servers(self, bot, getter, argument): result = None diff --git a/discord/gateway.py b/discord/gateway.py new file mode 100644 index 000000000..2b4fc4dc4 --- /dev/null +++ b/discord/gateway.py @@ -0,0 +1,402 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import sys +import websockets +import asyncio +import aiohttp +from . import utils, endpoints, compat +from .enums import Status +from .game import Game +from .errors import GatewayNotFound, ConnectionClosed, InvalidArgument +import logging +import zlib, time, json +from collections import namedtuple +import threading + +log = logging.getLogger(__name__) + +__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', + 'KeepAliveHandler' ] + +class ReconnectWebSocket(Exception): + """Signals to handle the RECONNECT opcode.""" + pass + +EventListener = namedtuple('EventListener', 'predicate event result future') + +class KeepAliveHandler(threading.Thread): + def __init__(self, *args, **kwargs): + ws = kwargs.pop('ws', None) + interval = kwargs.pop('interval', None) + threading.Thread.__init__(self, *args, **kwargs) + self.ws = ws + self.interval = interval + self.daemon = True + self._stop = threading.Event() + + def run(self): + while not self._stop.wait(self.interval): + data = self.get_payload() + msg = 'Keeping websocket alive with sequence {0[d]}'.format(data) + log.debug(msg) + coro = self.ws.send_as_json(data) + f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop) + try: + # block until sending is complete + f.result() + except Exception: + self.stop() + + def get_payload(self): + return { + 'op': self.ws.HEARTBEAT, + 'd': self.ws._connection.sequence + } + + def stop(self): + self._stop.set() + + +@asyncio.coroutine +def get_gateway(token, *, loop=None): + """Returns the gateway URL for connecting to the WebSocket. + + Parameters + ----------- + token : str + The discord authentication token. + loop + The event loop. + + Raises + ------ + GatewayNotFound + When the gateway is not returned gracefully. + """ + headers = { + 'authorization': token, + 'content-type': 'application/json' + } + + with aiohttp.ClientSession(loop=loop) as session: + resp = yield from session.get(endpoints.GATEWAY, headers=headers) + if resp.status != 200: + yield from resp.release() + raise GatewayNotFound() + data = yield from resp.json() + return data.get('url') + +class DiscordWebSocket(websockets.client.WebSocketClientProtocol): + """Implements a WebSocket for Discord's gateway v4. + + This is created through :func:`create_main_websocket`. Library + users should never create this manually. + + Attributes + ----------- + DISPATCH + Receive only. Denotes an event to be sent to Discord, such as READY. + HEARTBEAT + When received tells Discord to keep the connection alive. + When sent asks if your connection is currently alive. + IDENTIFY + Send only. Starts a new session. + PRESENCE + Send only. Updates your presence. + VOICE_STATE + Send only. Starts a new connection to a voice server. + VOICE_PING + Send only. Checks ping time to a voice server, do not use. + RESUME + Send only. Resumes an existing connection. + RECONNECT + Receive only. Tells the client to reconnect to a new gateway. + REQUEST_MEMBERS + Send only. Asks for the full member list of a server. + INVALIDATE_SESSION + Receive only. Tells the client to invalidate the session and IDENTIFY + again. + gateway + The gateway we are currently connected to. + token + The authentication token for discord. + """ + + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE = 3 + VOICE_STATE = 4 + VOICE_PING = 5 + RESUME = 6 + RECONNECT = 7 + REQUEST_MEMBERS = 8 + INVALIDATE_SESSION = 9 + + def __init__(self, *args, **kwargs): + super().__init__(*args, max_size=None, **kwargs) + # an empty dispatcher to prevent crashes + self._dispatch = lambda *args: None + # generic event listeners + self._dispatch_listeners = [] + # the keep alive + self._keep_alive = None + + @classmethod + @asyncio.coroutine + def connect(cls, dispatch, *, token=None, connection=None, loop=None): + """Creates a main websocket for Discord used for the client. + + Parameters + ---------- + token : str + The token for Discord authentication. + connection + The ConnectionState for the client. + dispatch + The function that dispatches events. + loop + The event loop to use. + + Returns + ------- + DiscordWebSocket + A websocket connected to Discord. + """ + + gateway = yield from get_gateway(token, loop=loop) + ws = yield from websockets.connect(gateway, loop=loop, klass=cls) + + # dynamically add attributes needed + ws.token = token + ws._connection = connection + ws._dispatch = dispatch + ws.gateway = gateway + + log.info('Created websocket connected to {}'.format(gateway)) + yield from ws.identify() + log.info('sent the identify payload to create the websocket') + return ws + + @classmethod + def from_client(cls, client): + """Creates a main websocket for Discord from a :class:`Client`. + + This is for internal use only. + """ + return cls.connect(client.dispatch, token=client.token, + connection=client.connection, + loop=client.loop) + + def wait_for(self, event, predicate, result): + """Waits for a DISPATCH'd event that meets the predicate. + + Parameters + ----------- + event : str + The event name in all upper case to wait for. + predicate + A function that takes a data parameter to check for event + properties. The data parameter is the 'd' key in the JSON message. + result + A function that takes the same data parameter and executes to send + the result to the future. + + Returns + -------- + asyncio.Future + A future to wait for. + """ + + future = asyncio.Future(loop=self.loop) + entry = EventListener(event=event, predicate=predicate, result=result, future=future) + self._dispatch_listeners.append(entry) + return future + + @asyncio.coroutine + def identify(self): + """Sends the IDENTIFY packet.""" + payload = { + 'op': self.IDENTIFY, + 'd': { + 'token': self.token, + 'properties': { + '$os': sys.platform, + '$browser': 'discord.py', + '$device': 'discord.py', + '$referrer': '', + '$referring_domain': '' + }, + 'compress': True, + 'large_threshold': 250, + 'v': 3 + } + } + yield from self.send_as_json(payload) + + @asyncio.coroutine + def received_message(self, msg): + self._dispatch('socket_raw_receive', msg) + + if isinstance(msg, bytes): + msg = zlib.decompress(msg, 15, 10490000) # This is 10 MiB + msg = msg.decode('utf-8') + + msg = json.loads(msg) + + log.debug('WebSocket Event: {}'.format(msg)) + self._dispatch('socket_response', msg) + + op = msg.get('op') + data = msg.get('d') + + if 's' in msg: + self._connection.sequence = msg['s'] + + if op == self.RECONNECT: + # "reconnect" can only be handled by the Client + # so we terminate our connection and raise an + # internal exception signalling to reconnect. + yield from self.close() + raise ReconnectWebSocket() + + if op == self.INVALIDATE_SESSION: + self._connection.sequence = None + self._connection.session_id = None + return + + if op != self.DISPATCH: + log.info('Unhandled op {}'.format(op)) + return + + event = msg.get('t') + is_ready = event == 'READY' + + if is_ready: + self._connection.clear() + self._connection.sequence = msg['s'] + self._connection.session_id = data['session_id'] + + if is_ready or event == 'RESUMED': + interval = data['heartbeat_interval'] / 1000.0 + self._keep_alive = KeepAliveHandler(ws=self, interval=interval) + self._keep_alive.start() + + parser = 'parse_' + event.lower() + + try: + func = getattr(self._connection, parser) + except AttributeError: + log.info('Unhandled event {}'.format(event)) + else: + func(data) + + # remove the dispatched listeners + removed = [] + for index, entry in enumerate(self._dispatch_listeners): + if entry.event != event: + continue + + future = entry.future + if future.cancelled(): + removed.append(index) + + try: + valid = entry.predicate(data) + except Exception as e: + future.set_exception(e) + removed.append(index) + else: + if valid: + future.set_result(entry.result) + removed.append(index) + + for index in reversed(removed): + del self._dispatch_listeners[index] + + @asyncio.coroutine + def poll_event(self): + """Polls for a DISPATCH event and handles the general gateway loop. + + Raises + ------ + ConnectionClosed + The websocket connection was terminated for unhandled reasons. + """ + try: + msg = yield from self.recv() + yield from self.received_message(msg) + except websockets.exceptions.ConnectionClosed as e: + if e.code in (4008, 4009) or e.code in range(1001, 1015): + raise ReconnectWebSocket() from e + else: + raise ConnectionClosed(e) from e + + @asyncio.coroutine + def send(self, data): + self._dispatch('socket_raw_send', data) + yield from super().send(data) + + @asyncio.coroutine + def send_as_json(self, data): + yield from super().send(utils.to_json(data)) + + @asyncio.coroutine + def change_presence(self, *, game=None, idle=None): + if game is not None and not isinstance(game, Game): + raise InvalidArgument('game must be of Game or None') + + idle_since = None if idle == False else int(time.time() * 1000) + sent_game = game and {'name': game.name} + + payload = { + 'op': self.PRESENCE, + 'd': { + 'game': sent_game, + 'idle_since': idle_since + } + } + + sent = utils.to_json(payload) + log.debug('Sending "{}" to change status'.format(sent)) + yield from self.send(sent) + + for server in self._connection.servers: + me = server.me + if me is None: + continue + + me.game = game + status = Status.idle if idle_since else Status.online + me.status = status + + @asyncio.coroutine + def close(self, code=1000, reason=''): + if self._keep_alive: + self._keep_alive.stop() + + yield from super().close(code, reason) diff --git a/discord/state.py b/discord/state.py index fc62927c9..fdffd444e 100644 --- a/discord/state.py +++ b/discord/state.py @@ -31,7 +31,7 @@ from .message import Message from .channel import Channel, PrivateChannel from .member import Member from .role import Role -from . import utils +from . import utils, compat from .enums import Status @@ -59,6 +59,8 @@ class ConnectionState: def clear(self): self.user = None + self.sequence = None + self.session_id = None self._servers = {} self._private_channels = {} # extra dict to look up private channels by user id @@ -180,7 +182,7 @@ class ConnectionState: self._add_private_channel(PrivateChannel(id=pm['id'], user=User(**pm['recipient']))) - utils.create_task(self._delay_ready(), loop=self.loop) + compat.create_task(self._delay_ready(), loop=self.loop) def parse_message_create(self, data): channel = self.get_channel(data.get('channel_id')) @@ -378,7 +380,7 @@ class ConnectionState: # since we're not waiting for 'useful' READY we'll just # do the chunk request here - utils.create_task(self._chunk_and_dispatch(server, unavailable), loop=self.loop) + compat.create_task(self._chunk_and_dispatch(server, unavailable), loop=self.loop) return # Dispatch available if newly available diff --git a/docs/api.rst b/docs/api.rst index af746be35..91d66ab26 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -526,6 +526,8 @@ The following exceptions are thrown by the library. .. autoexception:: GatewayNotFound +.. autoexception:: ConnectionClosed + .. autoexception:: discord.opus.OpusError .. autoexception:: discord.opus.OpusNotLoaded diff --git a/requirements.txt b/requirements.txt index 9da1be564..9221d17ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ aiohttp>=0.21.0,<0.22.0 -websockets==2.7 +websockets==3.1 PyNaCl==1.0.1