From 29ea58d0080e0e6f4931fa9cb7fc4c04a2b248df Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 13 Dec 2015 01:42:15 -0500 Subject: [PATCH] Implement cache of login credentials. Also add endpoints.ME to easily access the @me endpoint. --- discord/client.py | 52 +++++++++++++++++++++++++++++++++++++++++--- discord/endpoints.py | 1 + 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/discord/client.py b/discord/client.py index 42657f695..75b3c9a38 100644 --- a/discord/client.py +++ b/discord/client.py @@ -45,6 +45,7 @@ import websockets import logging, traceback import sys, time, re, json +import tempfile, os, hashlib log = logging.getLogger(__name__) request_logging_format = '{method} {response.url} has returned {response.status}' @@ -68,6 +69,10 @@ class Client: loop : Optional[event loop]. The `event loop`_ to use for asynchronous operations. Defaults to ``None``, in which case the default event loop is used via ``asyncio.get_event_loop()``. + cache_auth : Optional[bool] + Indicates if :meth:`login` should cache the authentication tokens. Defaults + to ``True``. The method in which the cache is written is done by writing to + disk to a temporary directory. Attributes ----------- @@ -101,6 +106,7 @@ class Client: self.voice = None self.loop = asyncio.get_event_loop() if loop is None else loop self._listeners = [] + self.cache_auth = options.get('cache_auth', True) max_messages = options.get('max_messages') if max_messages is None or max_messages < 100: @@ -131,6 +137,10 @@ class Client: # internals + def _get_cache_filename(self, email): + filename = hashlib.md5(email.encode('utf-8')).hexdigest() + return os.path.join(tempfile.gettempdir(), 'discord_py', filename) + def handle_message(self, message): removed = [] for i, (condition, future) in enumerate(self._listeners): @@ -510,6 +520,31 @@ class Client: usually when it isn't 200 or the known incorrect credentials passing status code. """ + + # attempt to read the token from cache + if self.cache_auth: + try: + log.info('attempting to login via cache') + cache_file = self._get_cache_filename(email) + with open(cache_file, 'r') as f: + log.info('login cache file found') + self.token = f.read() + self.headers['authorization'] = self.token + + check = yield from self.session.get(endpoints.ME, headers=self.headers) + if check.status == 200: + log.info('login cache token check succeeded') + yield from check.release() + self._is_logged_in = True + return + + # at this point our check failed + # so we have to login and get the proper token and then + # redo the cache + except OSError as e: + log.info('a problem occurred while opening login cache') + pass # file not found et al + payload = { 'email': email, 'password': password @@ -531,6 +566,18 @@ class Client: self.headers['authorization'] = self.token self._is_logged_in = True + # since we went through all this trouble + # let's make sure we don't have to do it again + if self.cache_auth: + try: + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + with open(cache_file, 'w') as f: + log.info('updating login cache') + f.write(self.token) + except OSError: + log.info('a problem occurred while updating the login cache') + pass + @asyncio.coroutine def logout(self): """|coro| @@ -683,7 +730,7 @@ class Client: 'recipient_id': user.id } - url = '{}/@me/channels'.format(endpoints.USERS) + url = '{}/channels'.format(endpoints.ME) r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) log.debug(request_logging_format.format(method='POST', response=r)) yield from utils._verify_successful_response(r) @@ -1216,8 +1263,7 @@ class Client: 'avatar': avatar } - url = '{0}/@me'.format(endpoints.USERS) - r = yield from self.session.patch(url, headers=self.headers, data=utils.to_json(payload)) + r = yield from self.session.patch(endpoints.ME, headers=self.headers, data=utils.to_json(payload)) log.debug(request_logging_format.format(method='PATCH', response=r)) yield from utils._verify_successful_response(r) diff --git a/discord/endpoints.py b/discord/endpoints.py index 7266b7751..d3e7f1979 100644 --- a/discord/endpoints.py +++ b/discord/endpoints.py @@ -28,6 +28,7 @@ BASE = 'https://discordapp.com' API_BASE = BASE + '/api' GATEWAY = API_BASE + '/gateway' USERS = API_BASE + '/users' +ME = USERS + '/@me' REGISTER = API_BASE + '/auth/register' LOGIN = API_BASE + '/auth/login' LOGOUT = API_BASE + '/auth/logout'