diff --git a/discord/client.py b/discord/client.py index 9e5f3c64f..76c2d8533 100644 --- a/discord/client.py +++ b/discord/client.py @@ -105,15 +105,15 @@ class Client: self.token = None self.gateway = None self.voice = None + self.session_id = None + self.sequence = 0 self.loop = asyncio.get_event_loop() if loop is None else loop self._listeners = [] self.cache_auth = options.get('cache_auth', True) - max_messages = options.get('max_messages') - if max_messages is None or max_messages < 100: - max_messages = 5000 - - self.connection = ConnectionState(self.dispatch, max_messages) + self.max_messages = options.get('max_messages') + if self.max_messages is None or self.max_messages < 100: + self.max_messages = 5000 # Blame React for this user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' @@ -180,7 +180,6 @@ class Client: log.info('a problem occurred while updating the login cache') pass - def handle_message(self, message): removed = [] for i, (condition, future) in enumerate(self._listeners): @@ -311,6 +310,7 @@ class Client: print('Ignoring exception in {}'.format(event_method), file=sys.stderr) traceback.print_exc() + @asyncio.coroutine def received_message(self, msg): log.debug('WebSocket Event: {}'.format(msg)) self.dispatch('socket_response', msg) @@ -318,6 +318,15 @@ class Client: op = msg.get('op') data = msg.get('d') + if 's' in msg: + self.sequence = msg['s'] + + if op == 7: + # redirect op code + yield from self.ws.close() + yield from self.redirect_websocket(data.get('url')) + return + if op != 0: log.info('Unhandled op {}'.format(op)) return @@ -325,6 +334,10 @@ class Client: event = msg.get('t') if event == 'READY': + self.connection = ConnectionState(self.dispatch, self.max_messages) + self.session_id = data['session_id'] + + if event == 'READY' or event == 'RESUMED': interval = data['heartbeat_interval'] / 1000.0 self.keep_alive = utils.create_task(self.keep_alive_handler(interval), loop=self.loop) @@ -352,30 +365,60 @@ class Client: log.info("Unhandled event {}".format(event)) @asyncio.coroutine - def _make_websocket(self): + def _make_websocket(self, initial=True): if not self.is_logged_in: raise ClientException('You must be logged in to connect') self.ws = yield from websockets.connect(self.gateway, loop=self.loop) self.ws.max_size = None log.info('Created websocket connected to {0.gateway}'.format(self)) + + if initial: + payload = { + 'op': 2, + 'd': { + 'token': self.token, + 'properties': { + '$os': sys.platform, + '$browser': 'discord.py', + '$device': 'discord.py', + '$referrer': '', + '$referring_domain': '' + }, + 'v': 3 + } + } + + yield from self.ws.send(utils.to_json(payload)) + log.info('sent the initial payload to create the websocket') + + @asyncio.coroutine + def redirect_websocket(self, url): + # if we get redirected then we need to recreate the websocket + # when this recreation happens we have to try to do a reconnection + log.info('redirecting websocket from {} to {}'.format(self.gateway, url)) + self.keep_alive_handler.cancel() + + self.gateway = url + yield from self._make_websocket(initial=False) + yield from self._reconnect_ws() + + if self.is_voice_connected(): + # update the websocket reference pointed to by voice + self.voice.main_ws = self.ws + + @asyncio.coroutine + def _reconnect_ws(self): payload = { - 'op': 2, + 'op': 6, 'd': { - 'token': self.token, - 'properties': { - '$os': sys.platform, - '$browser': 'discord.py', - '$device': 'discord.py', - '$referrer': '', - '$referring_domain': '' - }, - 'v': 3 + 'session_id': self.session_id, + 'seq': self.sequence } } + log.info('sending reconnection frame to websocket {}'.format(payload)) yield from self.ws.send(utils.to_json(payload)) - log.info('sent the initial payload to create the websocket') # properties @@ -636,10 +679,14 @@ class Client: while not self._closed: msg = yield from self.ws.recv() if msg is None: - yield from self.close() - break + if self.ws.close_code == 1012: + yield from self.redirect_websocket(self.gateway) + continue + else: + yield from self.close() + break - self.received_message(json.loads(msg)) + yield from self.received_message(json.loads(msg)) @asyncio.coroutine def close(self): @@ -654,11 +701,12 @@ class Client: yield from self.voice.disconnect() self.voice = None - yield from self.ws.close() + if self.ws.open: + yield from self.ws.close() + self.keep_alive.cancel() self._closed = True - @asyncio.coroutine def start(self, email, password): """|coro|