From 1fba1b06faca31d07c9296b2badabfe22f173001 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 12 Jun 2016 20:32:59 -0400 Subject: [PATCH] Rewrite HTTP handling significantly. This should have a more uniform approach to rate limit handling. Instead of queueing every request, wait until we receive a 429 and then block the requesting bucket until we're done being rate limited. This should reduce the number of 429s done by the API significantly (about 66% avg). This also consistently checks for 502 retries across all requests. --- discord/client.py | 545 ++++++++------------------------------------- discord/gateway.py | 38 +--- discord/http.py | 484 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 586 insertions(+), 481 deletions(-) create mode 100644 discord/http.py diff --git a/discord/client.py b/discord/client.py index 7f3118f08..8221afe10 100644 --- a/discord/client.py +++ b/discord/client.py @@ -42,6 +42,7 @@ from .enums import ChannelType, ServerRegion from .voice_client import VoiceClient from .iterators import LogsFromIterator from .gateway import * +from .http import HTTPClient import asyncio import aiohttp @@ -52,7 +53,6 @@ import sys, re import tempfile, os, hashlib import itertools import datetime -from random import randint as random_integer from collections import namedtuple PY35 = sys.version_info >= (3, 5) @@ -136,16 +136,8 @@ class Client: self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop) - # Blame Jake for this - user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' - - self.headers = { - 'content-type': 'application/json', - 'user-agent': user_agent.format(library_version, sys.version_info, aiohttp.__version__) - } - connector = options.pop('connector', None) - self.session = aiohttp.ClientSession(loop=self.loop, connector=connector) + self.http = HTTPClient(connector, loop=self.loop) self._closed = asyncio.Event(loop=self.loop) self._is_logged_in = asyncio.Event(loop=self.loop) @@ -157,23 +149,21 @@ class Client: filename = hashlib.md5(email.encode('utf-8')).hexdigest() return os.path.join(tempfile.gettempdir(), 'discord_py', filename) - @asyncio.coroutine - def _login_via_cache(self, email, password): + def _get_cache_token(self, email, password): try: log.info('attempting to login via cache') cache_file = self._get_cache_filename(email) self.email = email with open(cache_file, 'r') as f: log.info('login cache file found') - self.token = f.read() - self.headers['authorization'] = self.token + return f.read() # at this point our check failed # so we have to login and get the proper token and then # redo the cache except OSError: log.info('a problem occurred while opening login cache') - pass # file not found et al + return None # file not found et al def _update_cache(self, email, password): try: @@ -222,20 +212,30 @@ class Client: @asyncio.coroutine def _resolve_destination(self, destination): - if isinstance(destination, (Channel, PrivateChannel, Server)): - return destination.id + if isinstance(destination, Channel): + return destination.id, destination.server.id + elif isinstance(destination, PrivateChannel): + return destination.id, None + elif isinstance(destination, Server): + return destination.id, destination.id elif isinstance(destination, User): found = self.connection._get_private_channel_by_user(destination.id) if found is None: # Couldn't find the user, so start a PM with them first. channel = yield from self.start_private_message(destination) - return channel.id + return channel.id, None else: - return found.id + return found.id, None elif isinstance(destination, Object): - return destination.id + found = self.get_channel(destination.id) + if found is not None: + return (yield from self._resolve_destination(found)) + + # couldn't find it in cache so YOLO + return destination.id, destination.id else: - raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object') + fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}' + raise InvalidArgument(fmt.format(destination)) def __getattr__(self, name): if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'): @@ -291,55 +291,25 @@ class Client: @asyncio.coroutine def _login_1(self, token, **kwargs): log.info('logging in using static token') - self.token = token - self.email = None - if kwargs.pop('bot', True): - self.headers['authorization'] = 'Bot ' + self.token - else: - self.headers['authorization'] = self.token - - resp = yield from self.session.get(endpoints.ME, headers=self.headers) - yield from resp.release() - log.debug(request_logging_format.format(method='GET', response=resp)) - - if resp.status != 200: - if resp.status == 401: - raise LoginFailure('Improper token has been passed.') - else: - raise HTTPException(resp, None) - - log.info('token auth returned status code {}'.format(resp.status)) + yield from self.http.static_login(token, bot=kwargs.pop('bot', True)) self._is_logged_in.set() @asyncio.coroutine def _login_2(self, email, password, **kwargs): # attempt to read the token from cache if self.cache_auth: - yield from self._login_via_cache(email, password) - if self.is_logged_in: + token = self._get_cache_token() + try: + self.http.static_login(token, bot=False) + except: + log.info('cache auth token is out of date') + else: + self._is_logged_in.set() return - payload = { - 'email': email, - 'password': password - } - - data = utils.to_json(payload) - resp = yield from self.session.post(endpoints.LOGIN, data=data, headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=resp)) - if resp.status != 200: - yield from resp.release() - if resp.status == 400: - raise LoginFailure('Improper credentials have been passed.') - else: - raise HTTPException(resp, None) - log.info('logging in returned status code {}'.format(resp.status)) + yield from self.http.email_login(email, password) self.email = email - - body = yield from resp.json(encoding='utf-8') - self.token = body['token'] - self.headers['authorization'] = self.token self._is_logged_in.set() # since we went through all this trouble @@ -395,12 +365,10 @@ class Client: def logout(self): """|coro| - Logs out of Discord and closes all connections.""" - response = yield from self.session.post(endpoints.LOGOUT, headers=self.headers) - yield from response.release() + Logs out of Discord and closes all connections. + """ yield from self.close() self._is_logged_in.clear() - log.debug(request_logging_format.format(method='POST', response=response)) @asyncio.coroutine def connect(self): @@ -453,7 +421,7 @@ class Client: yield from self.ws.close() - yield from self.session.close() + yield from self.http.close() self._closed.set() self._is_ready.clear() @@ -774,43 +742,11 @@ class Client: if not isinstance(user, User): raise InvalidArgument('user argument must be a User') - payload = { - 'recipient_id': user.id - } - - url = '{}/channels'.format(endpoints.ME) - r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=r)) - yield from utils._verify_successful_response(r) - data = yield from r.json(encoding='utf-8') - log.debug(request_success_log.format(response=r, json=payload, data=data)) + data = yield from self.http.start_private_message(user.id) channel = PrivateChannel(id=data['id'], user=user) self.connection._add_private_channel(channel) return channel - @asyncio.coroutine - def _retry_helper(self, name, *args, retries=0, **kwargs): - req_kwargs = {'headers': self.headers} - req_kwargs.update(kwargs) - resp = yield from self.session.request(*args, **req_kwargs) - tmp = request_logging_format.format(method=resp.method, response=resp) - log_fmt = 'In {}, {}'.format(name, tmp) - log.debug(log_fmt) - - if resp.status == 502 and retries < 5: - # retry the 502 request unconditionally - log.info('Retrying the 502 request to ' + name) - yield from asyncio.sleep(retries + 1) - return (yield from self._retry_helper(name, *args, retries=retries + 1, **kwargs)) - - if resp.status == 429: - retry = float(resp.headers['Retry-After']) / 1000.0 - yield from resp.release() - yield from asyncio.sleep(retry) - return (yield from self._retry_helper(name, *args, retries=retries, **kwargs)) - - return resp - @asyncio.coroutine def send_message(self, destination, content, *, tts=False): """|coro| @@ -858,23 +794,11 @@ class Client: The message that was sent. """ - channel_id = yield from self._resolve_destination(destination) + channel_id, guild_id = yield from self._resolve_destination(destination) content = str(content) - url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id) - payload = { - 'content': content, - 'nonce': random_integer(-2**63, 2**63 - 1) - } - - if tts: - payload['tts'] = True - - resp = yield from self._retry_helper('send_message', 'POST', url, data=utils.to_json(payload)) - yield from utils._verify_successful_response(resp) - data = yield from resp.json(encoding='utf-8') - log.debug(request_success_log.format(response=resp, json=payload, data=data)) + data = yield from self.http.send_message(channel_id, content, guild_id=guild_id, tts=tts) channel = self.get_channel(data.get('channel_id')) message = Message(channel=channel, **data) return message @@ -895,14 +819,8 @@ class Client: The location to send the typing update. """ - channel_id = yield from self._resolve_destination(destination) - - url = '{base}/{id}/typing'.format(base=endpoints.CHANNELS, id=channel_id) - - response = yield from self.session.post(url, headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + channel_id, guild_id = yield from self._resolve_destination(destination) + yield from self.http.send_typing(channel_id) @asyncio.coroutine def send_file(self, destination, fp, *, filename=None, content=None, tts=False): @@ -951,34 +869,18 @@ class Client: The message sent. """ - channel_id = yield from self._resolve_destination(destination) - - url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id) - form = aiohttp.FormData() - - if content is not None: - form.add_field('content', str(content)) - - form.add_field('tts', 'true' if tts else 'false') - - # we don't want the content-type json in this request - headers = self.headers.copy() - headers.pop('content-type', None) + channel_id, guild_id = yield from self._resolve_destination(destination) try: - # attempt to open the file and send the request with open(fp, 'rb') as f: - form.add_field('file', f, filename=filename, content_type='application/octet-stream') - response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers) + buffer = f.read() + if filename is None: + filename = fp except TypeError: - form.add_field('file', fp, filename=filename, content_type='application/octet-stream') - response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers) - - log.debug(request_logging_format.format(method='POST', response=response)) - yield from utils._verify_successful_response(response) - data = yield from response.json(encoding='utf-8') - msg = 'POST {0.url} returned {0.status} with {1} response' - log.debug(msg.format(response, data)) + buffer = fp + + data = yield from self.http.send_file(channel_id, buffer, guild_id=guild_id, + filename=filename, content=content, tts=tts) channel = self.get_channel(data.get('channel_id')) message = Message(channel=channel, **data) return message @@ -1004,12 +906,8 @@ class Client: HTTPException Deleting the message failed. """ - - url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, message.channel.id, message.id) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + channel = message.channel + yield from self.http.delete_message(channel.id, message.id, channel.server.id) @asyncio.coroutine def delete_messages(self, messages): @@ -1045,16 +943,9 @@ class Client: if len(messages) > 100 or len(messages) < 2: raise ClientException('Can only delete messages in the range of [2, 100]') - channel_id = messages[0].channel.id - url = '{0}/{1}/messages/bulk_delete'.format(endpoints.CHANNELS, channel_id) - payload = { - 'messages': [m.id for m in messages] - } - - response = yield from self.session.post(url, headers=self.headers, data=utils.to_json(payload)) - log.debug(request_logging_format.format(method='POST', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + channel = messages[0].channel + message_ids = [m.id for m in messages] + yield from self.http.delete_messages(channel.id, message_ids, channel.server.id) @asyncio.coroutine def purge_from(self, channel, *, limit=100, check=None, before=None, after=None): @@ -1179,19 +1070,9 @@ class Client: channel = message.channel content = str(new_content) - url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, channel.id, message.id) - payload = { - 'content': content - } - - response = yield from self._retry_helper('edit_message', 'PATCH', url, data=utils.to_json(payload)) - log.debug(request_logging_format.format(method='PATCH', response=response)) - yield from utils._verify_successful_response(response) - data = yield from response.json(encoding='utf-8') - log.debug(request_success_log.format(response=response, json=payload, data=data)) + data = yield from self.http.edit_message(message.id, channel.id, content, guild_id=channel.server.id) return Message(channel=channel, **data) - @asyncio.coroutine def _logs_from(self, channel, limit=100, before=None, after=None): """|coro| @@ -1242,21 +1123,7 @@ class Client: if message.author == client.user: counter += 1 """ - url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id) - params = { - 'limit': limit - } - - if before: - params['before'] = before.id - if after: - params['after'] = after.id - - response = yield from self.session.get(url, params=params, headers=self.headers) - log.debug(request_logging_format.format(method='GET', response=response)) - yield from utils._verify_successful_response(response) - messages = yield from response.json(encoding='utf-8') - return messages + return self.http.logs_from(channel.id, limit, before=before, after=after) if PY35: def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False): @@ -1356,12 +1223,7 @@ class Client: HTTPException Kicking failed. """ - - url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.kick(member.id, member.server.id) @asyncio.coroutine def ban(self, member, delete_message_days=1): @@ -1390,16 +1252,7 @@ class Client: HTTPException Banning failed. """ - - params = { - 'delete-message-days': delete_message_days - } - - url = '{0}/{1.server.id}/bans/{1.id}'.format(endpoints.SERVERS, member) - response = yield from self.session.put(url, params=params, headers=self.headers) - log.debug(request_logging_format.format(method='PUT', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.ban(member.id, member.server.id, delete_message_days) @asyncio.coroutine def unban(self, server, user): @@ -1421,12 +1274,7 @@ class Client: HTTPException Unbanning failed. """ - - url = '{0}/{1.id}/bans/{2.id}'.format(endpoints.SERVERS, server, user) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.unban(user.id, server.id) @asyncio.coroutine def server_voice_state(self, member, *, mute=False, deafen=False): @@ -1456,17 +1304,7 @@ class Client: HTTPException The operation failed. """ - - url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) - payload = { - 'mute': mute, - 'deaf': deafen - } - - response = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.server_voice_state(member.id, member.server.id, mute=mute, deafen=deafen) @asyncio.coroutine def edit_profile(self, password=None, **fields): @@ -1527,30 +1365,21 @@ class Client: if not_bot_account and password is None: raise ClientException('Password is required for non-bot accounts.') - payload = { + args = { 'password': password, 'username': fields.get('username', self.user.name), 'avatar': avatar } if not_bot_account: - payload['email'] = fields.get('email', self.email) + args['email'] = fields.get('email', self.email) if 'new_password' in fields: - payload['new_password'] = fields['new_password'] - - - r = yield from self.session.patch(endpoints.ME, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - - data = yield from r.json(encoding='utf-8') - log.debug(request_success_log.format(response=r, json=payload, data=data)) + args['new_password'] = fields['new_password'] + yield from self.http.edit_profile(**args) if not_bot_account: - self.token = data['token'] self.email = data['email'] - self.headers['authorization'] = self.token if self.cache_auth: self._update_cache(self.email, password) @@ -1608,24 +1437,12 @@ class Client: Changing the nickname failed. """ + nickname = nickname if nickname else '' + if member == self.user: - fmt = '{0}/{1.server.id}/members/@me/nick' + yield from self.http.change_my_nickname(member.server.id, nickname) else: - fmt = '{0}/{1.server.id}/members/{1.id}' - - url = fmt.format(endpoints.SERVERS, member) - - payload = { - # oddly enough, this endpoint requires '' to clear the nickname - # instead of the more consistent 'null', this might change in the - # future, or not. - 'nick': nickname if nickname else '' - } - - r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - yield from r.release() + yield from self.http.change_nickname(member.server.id, member.id, nickname) # Channel management @@ -1662,26 +1479,7 @@ class Client: Editing the channel failed. """ - url = '{0}/{1.id}'.format(endpoints.CHANNELS, channel) - payload = { - 'name': options.get('name', channel.name), - 'topic': options.get('topic', channel.topic), - } - - user_limit = options.get('user_limit') - if user_limit is not None: - payload['user_limit'] = user_limit - - bitrate = options.get('bitrate') - if bitrate is not None: - payload['bitrate'] = bitrate - - r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - - data = yield from r.json(encoding='utf-8') - log.debug(request_success_log.format(response=r, json=payload, data=data)) + yield from self.http.edit_channel(channel.id, **options) @asyncio.coroutine def move_channel(self, channel, position): @@ -1735,13 +1533,7 @@ class Client: channels.insert(position, channel) payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)] - - r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - - yield from r.release() - log.debug(request_success_log.format(json=payload, response=r, data={})) + yield from self.http.patch(url, json=payload, bucket='move_channel') @asyncio.coroutine def create_channel(self, server, name, type=None): @@ -1779,18 +1571,7 @@ class Client: if type is None: type = ChannelType.text - payload = { - 'name': name, - 'type': str(type) - } - - url = '{0}/{1.id}/channels'.format(endpoints.SERVERS, server) - response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=response)) - yield from utils._verify_successful_response(response) - - data = yield from response.json(encoding='utf-8') - log.debug(request_success_log.format(response=response, data=data, json=payload)) + data = yield from self.http.create_channel(server.id, name, str(type)) channel = Channel(server=server, **data) return channel @@ -1817,12 +1598,7 @@ class Client: HTTPException Deleting the channel failed. """ - - url = '{}/{}'.format(endpoints.CHANNELS, channel.id) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.delete_channel(channel.id) # Server management @@ -1847,12 +1623,7 @@ class Client: HTTPException If leaving the server failed. """ - - url = '{}/@me/guilds/{.id}'.format(endpoints.USERS, server) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.leave_server(server.id) @asyncio.coroutine def delete_server(self, server): @@ -1874,11 +1645,7 @@ class Client: You do not have permissions to delete the server. """ - url = '{0}/{1.id}'.format(endpoints.SERVERS, server) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.delete_server(server.id) @asyncio.coroutine def create_server(self, name, region=None, icon=None): @@ -1918,17 +1685,7 @@ class Client: else: region = region.name - payload = { - 'icon': icon, - 'name': name, - 'region': region - } - - r = yield from self.session.post(endpoints.SERVERS, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=r)) - yield from utils._verify_successful_response(r) - data = yield from r.json(encoding='utf-8') - log.debug(request_success_log.format(response=r, json=payload, data=data)) + data = yield from self.http.create_server(name, region, icon) return Server(**data) @asyncio.coroutine @@ -1984,30 +1741,18 @@ class Client: else: icon = None - payload = { - 'region': str(fields.get('region', server.region)), - 'afk_timeout': fields.get('afk_timeout', server.afk_timeout), - 'icon': icon, - 'name': fields.get('name', server.name), - } - - afk_channel = fields.get('afk_channel') - if afk_channel is None: - afk_channel = server.afk_channel - - payload['afk_channel'] = getattr(afk_channel, 'id', None) + fields['icon'] = icon + if 'afk_channel' in fields: + fields['afk_channel_id'] = fields['afk_channel'].id if 'owner' in fields: if server.owner != server.me: raise InvalidArgument('To transfer ownership you must be the owner of the server.') - payload['owner_id'] = fields['owner'].id + fields['owner_id'] = fields['owner'].id + + yield from self.http.edit_server(server.id, **fields) - url = '{0}/{1.id}'.format(endpoints.SERVERS, server) - r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - yield from r.release() @asyncio.coroutine def get_bans(self, server): @@ -2036,11 +1781,7 @@ class Client: A list of :class:`User` that have been banned. """ - url = '{0}/{1.id}/bans'.format(endpoints.SERVERS, server) - resp = yield from self.session.get(url, headers=self.headers) - log.debug(request_logging_format.format(method='GET', response=resp)) - yield from utils._verify_successful_response(resp) - data = yield from resp.json(encoding='utf-8') + data = yield from self.http.get_bans(server.id) return [User(**user['user']) for user in data] # Invite management @@ -2092,20 +1833,7 @@ class Client: The invite that was created. """ - payload = { - 'max_age': options.get('max_age', 0), - 'max_uses': options.get('max_uses', 0), - 'temporary': options.get('temporary', False), - 'xkcdpass': options.get('xkcd', False) - } - - url = '{0}/{1.id}/invites'.format(endpoints.CHANNELS, destination) - response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=response)) - - yield from utils._verify_successful_response(response) - data = yield from response.json(encoding='utf-8') - log.debug(request_success_log.format(json=payload, response=response, data=data)) + data = yield from self.http.create_invite(destination.id, **options) self._fill_invite_data(data) return Invite(**data) @@ -2139,12 +1867,8 @@ class Client: The invite from the URL/ID. """ - destination = self._resolve_invite(url) - rurl = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) - response = yield from self.session.get(rurl, headers=self.headers) - log.debug(request_logging_format.format(method='GET', response=response)) - yield from utils._verify_successful_response(response) - data = yield from response.json(encoding='utf-8') + invite_id = self._resolve_invite(url) + data = yield from self.http.get_invite(invite_id) self._fill_invite_data(data) return Invite(**data) @@ -2174,11 +1898,7 @@ class Client: The list of invites that are currently active. """ - url = '{0}/{1.id}/invites'.format(endpoints.SERVERS, server) - resp = yield from self.session.get(url, headers=self.headers) - log.debug(request_logging_format.format(method='GET', response=resp)) - yield from utils._verify_successful_response(resp) - data = yield from resp.json(encoding='utf-8') + data = yield from self.http.invites_from(server.id) result = [] for invite in data: channel = server.get_channel(invite['channel']['id']) @@ -2210,12 +1930,8 @@ class Client: The invite is invalid or expired. """ - destination = self._resolve_invite(invite) - url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) - response = yield from self.session.post(url, headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + invite_id = self._resolve_invite(invite) + yield from self.http.accept_invite(invite_id) @asyncio.coroutine def delete_invite(self, invite): @@ -2241,12 +1957,8 @@ class Client: Revoking the invite failed. """ - destination = self._resolve_invite(invite) - url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + invite_id = self._resolve_invite(invite) + yield from self.http.delete_invite(invite_id) # Role management @@ -2298,13 +2010,7 @@ class Client: roles.append(role.id) payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)] - - r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - - data = yield from r.json() - log.debug(request_success_log.format(json=payload, response=r, data=data)) + yield from self.http.patch(url, json=payload, bucket='move_role') @asyncio.coroutine def edit_role(self, server, role, **fields): @@ -2345,11 +2051,6 @@ class Client: Editing the role failed. """ - url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role) - color = fields.get('color') - if color is None: - color = fields.get('colour', role.colour) - payload = { 'name': fields.get('name', role.name), 'permissions': fields.get('permissions', role.permissions).value, @@ -2358,12 +2059,7 @@ class Client: 'mentionable': fields.get('mentionable', role.mentionable) } - r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - - data = yield from r.json(encoding='utf-8') - log.debug(request_success_log.format(json=payload, response=r, data=data)) + yield from self.http.edit_role(server.id, role.id, **payload) @asyncio.coroutine def delete_role(self, server, role): @@ -2386,24 +2082,11 @@ class Client: Deleting the role failed. """ - url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.delete_role(server.id, role.id) @asyncio.coroutine def _replace_roles(self, member, roles): - url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) - - payload = { - 'roles': roles - } - - r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=r)) - yield from utils._verify_successful_response(r) - yield from r.release() + yield from self.http.replace_roles(member.id, member.server.id, roles) @asyncio.coroutine def add_roles(self, member, *roles): @@ -2521,12 +2204,7 @@ class Client: is stored in cache. """ - url = '{0}/{1.id}/roles'.format(endpoints.SERVERS, server) - r = yield from self.session.post(url, headers=self.headers) - log.debug(request_logging_format.format(method='POST', response=r)) - yield from utils._verify_successful_response(r) - - data = yield from r.json(encoding='utf-8') + data = yield from self.http.create_role(server.id) role = Role(server=server, **data) # we have to call edit because you can't pass a payload to the @@ -2581,8 +2259,6 @@ class Client: or the target type was not :class:`Role` or :class:`Member`. """ - url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target) - allow = Permissions.none() if allow is None else allow deny = Permissions.none() if deny is None else deny @@ -2592,23 +2268,14 @@ class Client: deny = deny.value allow = allow.value - payload = { - 'id': target.id, - 'allow': allow, - 'deny': deny - } - if isinstance(target, Member): - payload['type'] = 'member' + perm_type = 'member' elif isinstance(target, Role): - payload['type'] = 'role' + perm_type = 'role' else: raise InvalidArgument('target parameter must be either discord.Member or discord.Role') - r = yield from self.session.put(url, data=utils.to_json(payload), headers=self.headers) - log.debug(request_logging_format.format(method='PUT', response=r)) - yield from utils._verify_successful_response(r) - yield from r.release() + yield from self.http.edit_channel_permissions(channel.id, target.id, allow, deny, perm_type) @asyncio.coroutine def delete_channel_permissions(self, channel, target): @@ -2637,12 +2304,7 @@ class Client: HTTPException Deleting channel specific permissions failed. """ - - url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target) - response = yield from self.session.delete(url, headers=self.headers) - log.debug(request_logging_format.format(method='DELETE', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.delete_channel_permissions(channel.id, target.id) # Voice management @@ -2676,18 +2338,10 @@ class Client: You do not have permissions to move the member. """ - url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) - if getattr(channel, 'type', ChannelType.text) != ChannelType.voice: raise InvalidArgument('The channel provided must be a voice channel.') - payload = utils.to_json({ - 'channel_id': channel.id - }) - response = yield from self.session.patch(url, data=payload, headers=self.headers) - log.debug(request_logging_format.format(method='PATCH', response=response)) - yield from utils._verify_successful_response(response) - yield from response.release() + yield from self.http.move_member(member.id, member.server.id, channel.id) @asyncio.coroutine def join_voice_channel(self, channel): @@ -2817,10 +2471,7 @@ class Client: HTTPException Retrieving the information failed somehow. """ - url = '{}/@me'.format(endpoints.APPLICATIONS) - resp = yield from self.session.get(url, headers=self.headers) - yield from utils._verify_successful_response(resp) - data = yield from resp.json() + data = yield from self.http.application_info() return AppInfo(id=data['id'], name=data['name'], description=data['description'], icon=data['icon']) diff --git a/discord/gateway.py b/discord/gateway.py index 3a81b9914..382f6bca9 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -40,7 +40,7 @@ import struct log = logging.getLogger(__name__) -__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', +__all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket', 'KeepAliveHandler', 'VoiceKeepAliveHandler', 'DiscordVoiceWebSocket', 'ResumeWebSocket' ] @@ -97,36 +97,6 @@ class VoiceKeepAliveHandler(KeepAliveHandler): 'd': int(time.time() * 1000) } - -@asyncio.coroutine -def get_gateway(token, *, loop=None): - """Returns the gateway URL for connecting to the WebSocket. - - Parameters - ----------- - token : str - The discord authentication token. - loop - The event loop. - - Raises - ------ - GatewayNotFound - When the gateway is not returned gracefully. - """ - headers = { - 'authorization': token, - 'content-type': 'application/json' - } - - with aiohttp.ClientSession(loop=loop) as session: - resp = yield from session.get(endpoints.GATEWAY, headers=headers) - if resp.status != 200: - yield from resp.release() - raise GatewayNotFound() - data = yield from resp.json(encoding='utf-8') - return data.get('url') + '?encoding=json&v=4' - class DiscordWebSocket(websockets.client.WebSocketClientProtocol): """Implements a WebSocket for Discord's gateway v4. @@ -190,11 +160,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): This is for internal use only. """ - gateway = yield from get_gateway(client.token, loop=client.loop) + gateway = yield from client.http.get_gateway() ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) # dynamically add attributes needed - ws.token = client.token + ws.token = client.http.token ws._connection = client.connection ws._dispatch = client.dispatch ws.gateway = gateway @@ -505,7 +475,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): 'server_id': client.guild_id, 'user_id': client.user.id, 'session_id': client.session_id, - 'token': client.token + 'token': client.http.token } } diff --git a/discord/http.py b/discord/http.py new file mode 100644 index 000000000..15cd08ef9 --- /dev/null +++ b/discord/http.py @@ -0,0 +1,484 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import aiohttp +import asyncio +import json +import sys +import logging +import io +import inspect +import weakref +from random import randint as random_integer + +log = logging.getLogger(__name__) + +from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound +from . import utils, __version__ + +@asyncio.coroutine +def json_or_text(response): + text = yield from response.text(encoding='utf-8') + if response.headers['content-type'] == 'application/json': + return json.loads(text) + return text + +def _func_(): + # emulate __func__ from C++ + return inspect.currentframe().f_back.f_code.co_name + +class HTTPClient: + """Represents an HTTP client sending HTTP requests to the Discord API.""" + + BASE = 'https://discordapp.com' + API_BASE = BASE + '/api' + GATEWAY = API_BASE + '/gateway' + USERS = API_BASE + '/users' + ME = USERS + '/@me' + REGISTER = API_BASE + '/auth/register' + LOGIN = API_BASE + '/auth/login' + LOGOUT = API_BASE + '/auth/logout' + GUILDS = API_BASE + '/guilds' + CHANNELS = API_BASE + '/channels' + APPLICATIONS = API_BASE + '/oauth2/applications' + + SUCCESS_LOG = '{method} {url} with {json} has received {text}' + REQUEST_LOG = '{method} {url} has returned {status}' + + def __init__(self, connector=None, *, loop=None): + self.loop = asyncio.get_event_loop() if loop is None else loop + self.connector = connector + self.session = aiohttp.ClientSession(connector=connector, loop=self.loop) + self._locks = weakref.WeakValueDictionary() + self.token = None + self.bot_token = False + + user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' + self.user_agent = user_agent.format(__version__, sys.version_info, aiohttp.__version__) + + @asyncio.coroutine + def request(self, method, url, *, bucket=None, **kwargs): + lock = self._locks.get(bucket) + if lock is None: + lock = asyncio.Lock(loop=self.loop) + if bucket is not None: + self._locks[bucket] = lock + + # header creation + headers = { + 'User-Agent': self.user_agent, + } + + if self.token is not None: + headers['Authorization'] = 'Bot ' + self.token if self.bot_token else self.token + + # some checking if it's a JSON request + if 'json' in kwargs: + headers['Content-Type'] = 'application/json' + kwargs['data'] = utils.to_json(kwargs.pop('json')) + + kwargs['headers'] = headers + with (yield from lock): + for tries in range(5): + r = yield from self.session.request(method, url, **kwargs) + log.debug(self.REQUEST_LOG.format(method=method, url=url, status=r.status)) + try: + # even errors have text involved in them so this is safe to call + data = yield from json_or_text(r) + + # the request was successful so just return the text/json + if 300 > r.status >= 200: + log.debug(self.SUCCESS_LOG.format(method=method, url=url, + json=kwargs.get('data'), text=data)) + return data + + # we are being rate limited + if r.status == 429: + fmt = 'We are being rate limited. Retrying in {:.2} seconds. Handled under the bucket "{}"' + + # sleep a bit + retry_after = data['retry_after'] / 1000.0 + log.info(fmt.format(retry_after, bucket)) + yield from asyncio.sleep(retry_after) + continue + + # we've received a 502, unconditional retry + if r.status == 502 and tries <= 5: + yield from asyncio.sleep(1 + tries * 2) + continue + + # the usual error cases + if r.status == 403: + raise Forbidden(r, data) + elif r.status == 404: + raise NotFound(r, data) + else: + raise HTTPException(r, data) + finally: + # clean-up just in case + yield from r.release() + + def get(self, *args, **kwargs): + return self.request('GET', *args, **kwargs) + + def put(self, *args, **kwargs): + return self.request('PUT', *args, **kwargs) + + def patch(self, *args, **kwargs): + return self.request('PATCH', *args, **kwargs) + + def delete(self, *args, **kwargs): + return self.request('DELETE', *args, **kwargs) + + def post(self, *args, **kwargs): + return self.request('POST', *args, **kwargs) + + # state management + + @asyncio.coroutine + def close(self): + yield from self.session.close() + + def recreate(self): + self.session = aiohttp.ClientSession(self.connector, loop=self.loop) + + def _token(self, token, *, bot=True): + self.token = token + self.bot_token = bot + + # login management + + @asyncio.coroutine + def email_login(self, email, password): + payload = { + 'email': email, + 'password': password + } + + try: + data = yield from self.post(self.LOGIN, json=payload, bucket=_func_()) + except HTTPException as e: + if e.response.status == 400: + raise LoginFailure('Improper credentials have been passed.') from e + raise + + self._token(data['token'], bot=False) + return data + + @asyncio.coroutine + def static_login(self, token, *, bot): + old_state = (self.token, self.bot_token) + self._token(token, bot=bot) + + try: + data = yield from self.get(self.ME) + except HTTPException as e: + self._token(*old_state) + if e.response.status == 401: + raise LoginFailure('Improper token has been passed.') from e + raise e + + return data + + def logout(self): + return self.post(self.LOGOUT, bucket=_func_()) + + # Message management + + def start_private_message(self, user_id): + payload = { + 'recipient_id': user_id + } + + return self.post(self.ME + '/channels', json=payload, bucket=_func_()) + + def send_message(self, channel_id, content, *, guild_id=None, tts=False): + url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) + payload = { + 'content': str(content), + 'nonce': random_integer(-2**63, 2**63 - 1) + } + + if tts: + payload['tts'] = True + + return self.post(url, json=payload, bucket='messages:' + str(guild_id)) + + def send_typing(self, channel_id): + url = '{0.CHANNELS}/{1}/typing'.format(self, channel_id) + return self.post(url, bucket=_func_()) + + def send_file(self, channel_id, buffer, *, guild_id=None, filename=None, content=None, tts=False): + url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) + form = aiohttp.FormData() + + if content is not None: + form.add_field('content', str(content)) + + form.add_field('tts', 'true' if tts else 'false') + form.add_field('file', io.BytesIO(buffer), filename=filename, content_type='application/octet-stream') + + return self.post(url, data=form, bucket='messages:' + str(guild_id)) + + def delete_message(self, channel_id, message_id, guild_id=None): + url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) + bucket = '{}:{}'.format(_func_(), guild_id) + return self.delete(url, bucket=bucket) + + def delete_messages(self, channel_id, message_ids, guild_id=None): + url = '{0.CHANNELS}/{1}/messages/bulk_delete'.format(self, channel_id) + payload = { + 'messages': message_ids + } + bucket = '{}:{}'.format(_func_(), guild_id) + return self.post(url, json=payload, bucket=bucket) + + def edit_message(self, message_id, channel_id, content, *, guild_id=None): + url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) + payload = { + 'content': str(content) + } + return self.patch(url, json=payload, bucket='messages:' + str(guild_id)) + + + def logs_from(self, channel_id, limit, before=None, after=None): + url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) + params = { + 'limit': limit + } + + if before: + params['before'] = before + if after: + params['after'] = after + + return self.get(url, params=params, bucket=_func_()) + + # Member management + + def kick(self, user_id, guild_id): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + return self.delete(url, bucket=_func_()) + + def ban(self, user_id, guild_id, delete_message_days=1): + url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id) + params = { + 'delete-message-days': delete_message_days + } + return self.put(url, params=params, bucket=_func_()) + + def unban(self, user_id, guild_id): + url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id) + return self.delete(url, bucket=_func_()) + + def server_voice_state(self, user_id, guild_id, *, mute=False, deafen=False): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'mute': mute, + 'deafen': deafen + } + return self.patch(url, json=payload, bucket='members:' + str(guild_id)) + + def edit_profile(self, password, username, avatar, **fields): + payload = { + 'password': password, + 'username': username, + 'avatar': avatar + } + + if 'email' in fields: + payload['email'] = fields['email'] + + if 'new_password' in fields: + payload['new_password'] = fields['new_password'] + + return self.patch(self.ME, json=payload, bucket=_func_()) + + def change_my_nickname(self, guild_id, nickname): + url = '{0.GUILDS}/{1}/members/@me/nick'.format(self, guild_id) + payload = { + 'nick': nickname + } + bucket = '{}:{}'.format(_func_(), guild_id) + return self.patch(url, json=payload, bucket=bucket) + + def change_nickname(self, guild_id, user_id, nickname): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'nick': nickname + } + bucket = '{}:{}'.format(_func_(), guild_id) + return self.patch(url, json=payload, bucket=bucket) + + # Channel management + + def edit_channel(self, channel_id, **options): + url = '{0.CHANNELS}/{1}'.format(self, channel_id) + + valid_keys = ('name', 'topic', 'bitrate', 'user_limit') + payload = { + k: v for k, v in options.items() if k in valid_keys + } + + return self.patch(url, json=payload, bucket=_func_()) + + def create_channel(self, guild_id, name, channe_type): + url = '{0.GUILDS}/{1}/channels'.format(self, guild_id) + payload = { + 'name': name, + 'type': channe_type + } + + return self.post(url, json=payload, bucket=_func_()) + + def delete_channel(self, channel_id): + url = '{0.CHANNELS}/{1}'.format(self, channel_id) + return self.delete(url, bucket=_func_()) + + # Server management + + def leave_server(self, guild_id): + url = '{0.USERS}/@me/guilds/{1}'.format(self, guild_id) + return self.delete(url, bucket=_func_()) + + def delete_server(self, guild_id): + url = '{0.GUILDS}/{1}'.format(self, guild_id) + return self.delete(url, bucket=_func_()) + + def create_server(self, name, region, icon): + payload = { + 'name': name, + 'icon': icon, + 'region': region + } + + return self.post(self.GUILDS, json=payload, bucket=_func_()) + + def edit_server(self, guild_id, **fields): + valid_keys = ('name', 'region', 'icon', 'afk_timeout', 'owner_id', + 'afk_channel_id', 'splash', 'verification_level') + + payload = { + k: v for k, v in fields.items() if k in valid_keys + } + + url = '{0.GUILDS}/{1}'.format(self, guild_id) + return self.patch(url, json=payload, bucket=_func_()) + + def get_bans(self, guild_id): + url = '{0.GUILDS}/{1}/bans'.format(self, guild_id) + return self.get(url, bucket=_func_()) + + # Invite management + + def create_invite(self, channel_id, **options): + url = '{0.CHANNELS}/{1}/invites'.format(self, channel_id) + payload = { + 'max_age': options.get('max_age', 0), + 'max_uses': options.get('max_uses', 0), + 'temporary': options.get('temporary', False), + 'xkcdpass': options.get('xkcd', False) + } + + return self.post(url, json=payload, bucket=_func_()) + + def get_invite(self, invite_id): + url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) + return self.get(url, bucket=_func_()) + + def invites_from(self, guild_id): + url = '{0.GUILDS}/{1}/invites'.format(self, guild_id) + return self.get(url, bucket=_func_()) + + def accept_invite(self, invite_id): + url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) + return self.post(url, bucket=_func_()) + + def delete_invite(self, invite_id): + url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) + return self.delete(url, bucket=_func_()) + + # Role management + + def edit_role(self, guild_id, role_id, **fields): + url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id) + valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable') + payload = { + k: v for k, v in fields.items() if k in valid_keys + } + return self.patch(url, json=payload, bucket='roles:' + str(guild_id)) + + def delete_role(self, guild_id, role_id): + url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id) + return self.delete(url, bucket=_func_()) + + def replace_roles(self, user_id, guild_id, role_ids): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'roles': role_ids + } + return self.patch(url, json=payload, bucket='members:' + str(guild_id)) + + def create_role(self, guild_id): + url = '{0.GUILDS}/{1}/roles'.format(self, guild_id) + return self.post(url, bucket=_func_()) + + def edit_channel_permissions(self, channel_id, target, allow, deny, type): + url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target) + payload = { + 'id': target, + 'allow': allow, + 'deny': deny, + 'type': type + } + return self.put(url, json=payload, bucket=_func_()) + + def delete_channel_permissions(self, channel_id, target): + url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target) + return self.delete(url, bucket=_func_()) + + # Voice management + + def move_member(self, user_id, guild_id, channel_id): + url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) + payload = { + 'channel_id': channel_id + } + return self.patch(url, json=payload, bucket='members:' + str(guild_id)) + + # Misc + + def application_info(self): + url = '{0.APPLICATIONS}/@me'.format(self) + return self.get(url, bucket=_func_()) + + @asyncio.coroutine + def get_gateway(self): + try: + data = yield from self.get(self.GATEWAY, bucket=_func_()) + except HTTPException as e: + raise GatewayNotFound() from e + return data.get('url') + '?encoding=json&v=4'