Browse Source

Rewrite HTTP handling significantly.

This should have a more uniform approach to rate limit handling. Instead
of queueing every request, wait until we receive a 429 and then block
the requesting bucket until we're done being rate limited. This should
reduce the number of 429s done by the API significantly (about 66% avg).

This also consistently checks for 502 retries across all requests.
pull/244/head
Rapptz 9 years ago
parent
commit
1fba1b06fa
  1. 545
      discord/client.py
  2. 38
      discord/gateway.py
  3. 484
      discord/http.py

545
discord/client.py

@ -42,6 +42,7 @@ from .enums import ChannelType, ServerRegion
from .voice_client import VoiceClient
from .iterators import LogsFromIterator
from .gateway import *
from .http import HTTPClient
import asyncio
import aiohttp
@ -52,7 +53,6 @@ import sys, re
import tempfile, os, hashlib
import itertools
import datetime
from random import randint as random_integer
from collections import namedtuple
PY35 = sys.version_info >= (3, 5)
@ -136,16 +136,8 @@ class Client:
self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop)
# Blame Jake for this
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
self.headers = {
'content-type': 'application/json',
'user-agent': user_agent.format(library_version, sys.version_info, aiohttp.__version__)
}
connector = options.pop('connector', None)
self.session = aiohttp.ClientSession(loop=self.loop, connector=connector)
self.http = HTTPClient(connector, loop=self.loop)
self._closed = asyncio.Event(loop=self.loop)
self._is_logged_in = asyncio.Event(loop=self.loop)
@ -157,23 +149,21 @@ class Client:
filename = hashlib.md5(email.encode('utf-8')).hexdigest()
return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
@asyncio.coroutine
def _login_via_cache(self, email, password):
def _get_cache_token(self, email, password):
try:
log.info('attempting to login via cache')
cache_file = self._get_cache_filename(email)
self.email = email
with open(cache_file, 'r') as f:
log.info('login cache file found')
self.token = f.read()
self.headers['authorization'] = self.token
return f.read()
# at this point our check failed
# so we have to login and get the proper token and then
# redo the cache
except OSError:
log.info('a problem occurred while opening login cache')
pass # file not found et al
return None # file not found et al
def _update_cache(self, email, password):
try:
@ -222,20 +212,30 @@ class Client:
@asyncio.coroutine
def _resolve_destination(self, destination):
if isinstance(destination, (Channel, PrivateChannel, Server)):
return destination.id
if isinstance(destination, Channel):
return destination.id, destination.server.id
elif isinstance(destination, PrivateChannel):
return destination.id, None
elif isinstance(destination, Server):
return destination.id, destination.id
elif isinstance(destination, User):
found = self.connection._get_private_channel_by_user(destination.id)
if found is None:
# Couldn't find the user, so start a PM with them first.
channel = yield from self.start_private_message(destination)
return channel.id
return channel.id, None
else:
return found.id
return found.id, None
elif isinstance(destination, Object):
return destination.id
found = self.get_channel(destination.id)
if found is not None:
return (yield from self._resolve_destination(found))
# couldn't find it in cache so YOLO
return destination.id, destination.id
else:
raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object')
fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}'
raise InvalidArgument(fmt.format(destination))
def __getattr__(self, name):
if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'):
@ -291,55 +291,25 @@ class Client:
@asyncio.coroutine
def _login_1(self, token, **kwargs):
log.info('logging in using static token')
self.token = token
self.email = None
if kwargs.pop('bot', True):
self.headers['authorization'] = 'Bot ' + self.token
else:
self.headers['authorization'] = self.token
resp = yield from self.session.get(endpoints.ME, headers=self.headers)
yield from resp.release()
log.debug(request_logging_format.format(method='GET', response=resp))
if resp.status != 200:
if resp.status == 401:
raise LoginFailure('Improper token has been passed.')
else:
raise HTTPException(resp, None)
log.info('token auth returned status code {}'.format(resp.status))
yield from self.http.static_login(token, bot=kwargs.pop('bot', True))
self._is_logged_in.set()
@asyncio.coroutine
def _login_2(self, email, password, **kwargs):
# attempt to read the token from cache
if self.cache_auth:
yield from self._login_via_cache(email, password)
if self.is_logged_in:
token = self._get_cache_token()
try:
self.http.static_login(token, bot=False)
except:
log.info('cache auth token is out of date')
else:
self._is_logged_in.set()
return
payload = {
'email': email,
'password': password
}
data = utils.to_json(payload)
resp = yield from self.session.post(endpoints.LOGIN, data=data, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=resp))
if resp.status != 200:
yield from resp.release()
if resp.status == 400:
raise LoginFailure('Improper credentials have been passed.')
else:
raise HTTPException(resp, None)
log.info('logging in returned status code {}'.format(resp.status))
yield from self.http.email_login(email, password)
self.email = email
body = yield from resp.json(encoding='utf-8')
self.token = body['token']
self.headers['authorization'] = self.token
self._is_logged_in.set()
# since we went through all this trouble
@ -395,12 +365,10 @@ class Client:
def logout(self):
"""|coro|
Logs out of Discord and closes all connections."""
response = yield from self.session.post(endpoints.LOGOUT, headers=self.headers)
yield from response.release()
Logs out of Discord and closes all connections.
"""
yield from self.close()
self._is_logged_in.clear()
log.debug(request_logging_format.format(method='POST', response=response))
@asyncio.coroutine
def connect(self):
@ -453,7 +421,7 @@ class Client:
yield from self.ws.close()
yield from self.session.close()
yield from self.http.close()
self._closed.set()
self._is_ready.clear()
@ -774,43 +742,11 @@ class Client:
if not isinstance(user, User):
raise InvalidArgument('user argument must be a User')
payload = {
'recipient_id': user.id
}
url = '{}/channels'.format(endpoints.ME)
r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=r))
yield from utils._verify_successful_response(r)
data = yield from r.json(encoding='utf-8')
log.debug(request_success_log.format(response=r, json=payload, data=data))
data = yield from self.http.start_private_message(user.id)
channel = PrivateChannel(id=data['id'], user=user)
self.connection._add_private_channel(channel)
return channel
@asyncio.coroutine
def _retry_helper(self, name, *args, retries=0, **kwargs):
req_kwargs = {'headers': self.headers}
req_kwargs.update(kwargs)
resp = yield from self.session.request(*args, **req_kwargs)
tmp = request_logging_format.format(method=resp.method, response=resp)
log_fmt = 'In {}, {}'.format(name, tmp)
log.debug(log_fmt)
if resp.status == 502 and retries < 5:
# retry the 502 request unconditionally
log.info('Retrying the 502 request to ' + name)
yield from asyncio.sleep(retries + 1)
return (yield from self._retry_helper(name, *args, retries=retries + 1, **kwargs))
if resp.status == 429:
retry = float(resp.headers['Retry-After']) / 1000.0
yield from resp.release()
yield from asyncio.sleep(retry)
return (yield from self._retry_helper(name, *args, retries=retries, **kwargs))
return resp
@asyncio.coroutine
def send_message(self, destination, content, *, tts=False):
"""|coro|
@ -858,23 +794,11 @@ class Client:
The message that was sent.
"""
channel_id = yield from self._resolve_destination(destination)
channel_id, guild_id = yield from self._resolve_destination(destination)
content = str(content)
url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id)
payload = {
'content': content,
'nonce': random_integer(-2**63, 2**63 - 1)
}
if tts:
payload['tts'] = True
resp = yield from self._retry_helper('send_message', 'POST', url, data=utils.to_json(payload))
yield from utils._verify_successful_response(resp)
data = yield from resp.json(encoding='utf-8')
log.debug(request_success_log.format(response=resp, json=payload, data=data))
data = yield from self.http.send_message(channel_id, content, guild_id=guild_id, tts=tts)
channel = self.get_channel(data.get('channel_id'))
message = Message(channel=channel, **data)
return message
@ -895,14 +819,8 @@ class Client:
The location to send the typing update.
"""
channel_id = yield from self._resolve_destination(destination)
url = '{base}/{id}/typing'.format(base=endpoints.CHANNELS, id=channel_id)
response = yield from self.session.post(url, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
channel_id, guild_id = yield from self._resolve_destination(destination)
yield from self.http.send_typing(channel_id)
@asyncio.coroutine
def send_file(self, destination, fp, *, filename=None, content=None, tts=False):
@ -951,34 +869,18 @@ class Client:
The message sent.
"""
channel_id = yield from self._resolve_destination(destination)
url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id)
form = aiohttp.FormData()
if content is not None:
form.add_field('content', str(content))
form.add_field('tts', 'true' if tts else 'false')
# we don't want the content-type json in this request
headers = self.headers.copy()
headers.pop('content-type', None)
channel_id, guild_id = yield from self._resolve_destination(destination)
try:
# attempt to open the file and send the request
with open(fp, 'rb') as f:
form.add_field('file', f, filename=filename, content_type='application/octet-stream')
response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers)
buffer = f.read()
if filename is None:
filename = fp
except TypeError:
form.add_field('file', fp, filename=filename, content_type='application/octet-stream')
response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers)
log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response)
data = yield from response.json(encoding='utf-8')
msg = 'POST {0.url} returned {0.status} with {1} response'
log.debug(msg.format(response, data))
buffer = fp
data = yield from self.http.send_file(channel_id, buffer, guild_id=guild_id,
filename=filename, content=content, tts=tts)
channel = self.get_channel(data.get('channel_id'))
message = Message(channel=channel, **data)
return message
@ -1004,12 +906,8 @@ class Client:
HTTPException
Deleting the message failed.
"""
url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, message.channel.id, message.id)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
channel = message.channel
yield from self.http.delete_message(channel.id, message.id, channel.server.id)
@asyncio.coroutine
def delete_messages(self, messages):
@ -1045,16 +943,9 @@ class Client:
if len(messages) > 100 or len(messages) < 2:
raise ClientException('Can only delete messages in the range of [2, 100]')
channel_id = messages[0].channel.id
url = '{0}/{1}/messages/bulk_delete'.format(endpoints.CHANNELS, channel_id)
payload = {
'messages': [m.id for m in messages]
}
response = yield from self.session.post(url, headers=self.headers, data=utils.to_json(payload))
log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
channel = messages[0].channel
message_ids = [m.id for m in messages]
yield from self.http.delete_messages(channel.id, message_ids, channel.server.id)
@asyncio.coroutine
def purge_from(self, channel, *, limit=100, check=None, before=None, after=None):
@ -1179,19 +1070,9 @@ class Client:
channel = message.channel
content = str(new_content)
url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, channel.id, message.id)
payload = {
'content': content
}
response = yield from self._retry_helper('edit_message', 'PATCH', url, data=utils.to_json(payload))
log.debug(request_logging_format.format(method='PATCH', response=response))
yield from utils._verify_successful_response(response)
data = yield from response.json(encoding='utf-8')
log.debug(request_success_log.format(response=response, json=payload, data=data))
data = yield from self.http.edit_message(message.id, channel.id, content, guild_id=channel.server.id)
return Message(channel=channel, **data)
@asyncio.coroutine
def _logs_from(self, channel, limit=100, before=None, after=None):
"""|coro|
@ -1242,21 +1123,7 @@ class Client:
if message.author == client.user:
counter += 1
"""
url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id)
params = {
'limit': limit
}
if before:
params['before'] = before.id
if after:
params['after'] = after.id
response = yield from self.session.get(url, params=params, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=response))
yield from utils._verify_successful_response(response)
messages = yield from response.json(encoding='utf-8')
return messages
return self.http.logs_from(channel.id, limit, before=before, after=after)
if PY35:
def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False):
@ -1356,12 +1223,7 @@ class Client:
HTTPException
Kicking failed.
"""
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.kick(member.id, member.server.id)
@asyncio.coroutine
def ban(self, member, delete_message_days=1):
@ -1390,16 +1252,7 @@ class Client:
HTTPException
Banning failed.
"""
params = {
'delete-message-days': delete_message_days
}
url = '{0}/{1.server.id}/bans/{1.id}'.format(endpoints.SERVERS, member)
response = yield from self.session.put(url, params=params, headers=self.headers)
log.debug(request_logging_format.format(method='PUT', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.ban(member.id, member.server.id, delete_message_days)
@asyncio.coroutine
def unban(self, server, user):
@ -1421,12 +1274,7 @@ class Client:
HTTPException
Unbanning failed.
"""
url = '{0}/{1.id}/bans/{2.id}'.format(endpoints.SERVERS, server, user)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.unban(user.id, server.id)
@asyncio.coroutine
def server_voice_state(self, member, *, mute=False, deafen=False):
@ -1456,17 +1304,7 @@ class Client:
HTTPException
The operation failed.
"""
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
payload = {
'mute': mute,
'deaf': deafen
}
response = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.server_voice_state(member.id, member.server.id, mute=mute, deafen=deafen)
@asyncio.coroutine
def edit_profile(self, password=None, **fields):
@ -1527,30 +1365,21 @@ class Client:
if not_bot_account and password is None:
raise ClientException('Password is required for non-bot accounts.')
payload = {
args = {
'password': password,
'username': fields.get('username', self.user.name),
'avatar': avatar
}
if not_bot_account:
payload['email'] = fields.get('email', self.email)
args['email'] = fields.get('email', self.email)
if 'new_password' in fields:
payload['new_password'] = fields['new_password']
r = yield from self.session.patch(endpoints.ME, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
data = yield from r.json(encoding='utf-8')
log.debug(request_success_log.format(response=r, json=payload, data=data))
args['new_password'] = fields['new_password']
yield from self.http.edit_profile(**args)
if not_bot_account:
self.token = data['token']
self.email = data['email']
self.headers['authorization'] = self.token
if self.cache_auth:
self._update_cache(self.email, password)
@ -1608,24 +1437,12 @@ class Client:
Changing the nickname failed.
"""
nickname = nickname if nickname else ''
if member == self.user:
fmt = '{0}/{1.server.id}/members/@me/nick'
yield from self.http.change_my_nickname(member.server.id, nickname)
else:
fmt = '{0}/{1.server.id}/members/{1.id}'
url = fmt.format(endpoints.SERVERS, member)
payload = {
# oddly enough, this endpoint requires '' to clear the nickname
# instead of the more consistent 'null', this might change in the
# future, or not.
'nick': nickname if nickname else ''
}
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
yield from r.release()
yield from self.http.change_nickname(member.server.id, member.id, nickname)
# Channel management
@ -1662,26 +1479,7 @@ class Client:
Editing the channel failed.
"""
url = '{0}/{1.id}'.format(endpoints.CHANNELS, channel)
payload = {
'name': options.get('name', channel.name),
'topic': options.get('topic', channel.topic),
}
user_limit = options.get('user_limit')
if user_limit is not None:
payload['user_limit'] = user_limit
bitrate = options.get('bitrate')
if bitrate is not None:
payload['bitrate'] = bitrate
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
data = yield from r.json(encoding='utf-8')
log.debug(request_success_log.format(response=r, json=payload, data=data))
yield from self.http.edit_channel(channel.id, **options)
@asyncio.coroutine
def move_channel(self, channel, position):
@ -1735,13 +1533,7 @@ class Client:
channels.insert(position, channel)
payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)]
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
yield from r.release()
log.debug(request_success_log.format(json=payload, response=r, data={}))
yield from self.http.patch(url, json=payload, bucket='move_channel')
@asyncio.coroutine
def create_channel(self, server, name, type=None):
@ -1779,18 +1571,7 @@ class Client:
if type is None:
type = ChannelType.text
payload = {
'name': name,
'type': str(type)
}
url = '{0}/{1.id}/channels'.format(endpoints.SERVERS, server)
response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response)
data = yield from response.json(encoding='utf-8')
log.debug(request_success_log.format(response=response, data=data, json=payload))
data = yield from self.http.create_channel(server.id, name, str(type))
channel = Channel(server=server, **data)
return channel
@ -1817,12 +1598,7 @@ class Client:
HTTPException
Deleting the channel failed.
"""
url = '{}/{}'.format(endpoints.CHANNELS, channel.id)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.delete_channel(channel.id)
# Server management
@ -1847,12 +1623,7 @@ class Client:
HTTPException
If leaving the server failed.
"""
url = '{}/@me/guilds/{.id}'.format(endpoints.USERS, server)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.leave_server(server.id)
@asyncio.coroutine
def delete_server(self, server):
@ -1874,11 +1645,7 @@ class Client:
You do not have permissions to delete the server.
"""
url = '{0}/{1.id}'.format(endpoints.SERVERS, server)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.delete_server(server.id)
@asyncio.coroutine
def create_server(self, name, region=None, icon=None):
@ -1918,17 +1685,7 @@ class Client:
else:
region = region.name
payload = {
'icon': icon,
'name': name,
'region': region
}
r = yield from self.session.post(endpoints.SERVERS, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=r))
yield from utils._verify_successful_response(r)
data = yield from r.json(encoding='utf-8')
log.debug(request_success_log.format(response=r, json=payload, data=data))
data = yield from self.http.create_server(name, region, icon)
return Server(**data)
@asyncio.coroutine
@ -1984,30 +1741,18 @@ class Client:
else:
icon = None
payload = {
'region': str(fields.get('region', server.region)),
'afk_timeout': fields.get('afk_timeout', server.afk_timeout),
'icon': icon,
'name': fields.get('name', server.name),
}
afk_channel = fields.get('afk_channel')
if afk_channel is None:
afk_channel = server.afk_channel
payload['afk_channel'] = getattr(afk_channel, 'id', None)
fields['icon'] = icon
if 'afk_channel' in fields:
fields['afk_channel_id'] = fields['afk_channel'].id
if 'owner' in fields:
if server.owner != server.me:
raise InvalidArgument('To transfer ownership you must be the owner of the server.')
payload['owner_id'] = fields['owner'].id
fields['owner_id'] = fields['owner'].id
yield from self.http.edit_server(server.id, **fields)
url = '{0}/{1.id}'.format(endpoints.SERVERS, server)
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
yield from r.release()
@asyncio.coroutine
def get_bans(self, server):
@ -2036,11 +1781,7 @@ class Client:
A list of :class:`User` that have been banned.
"""
url = '{0}/{1.id}/bans'.format(endpoints.SERVERS, server)
resp = yield from self.session.get(url, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=resp))
yield from utils._verify_successful_response(resp)
data = yield from resp.json(encoding='utf-8')
data = yield from self.http.get_bans(server.id)
return [User(**user['user']) for user in data]
# Invite management
@ -2092,20 +1833,7 @@ class Client:
The invite that was created.
"""
payload = {
'max_age': options.get('max_age', 0),
'max_uses': options.get('max_uses', 0),
'temporary': options.get('temporary', False),
'xkcdpass': options.get('xkcd', False)
}
url = '{0}/{1.id}/invites'.format(endpoints.CHANNELS, destination)
response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response)
data = yield from response.json(encoding='utf-8')
log.debug(request_success_log.format(json=payload, response=response, data=data))
data = yield from self.http.create_invite(destination.id, **options)
self._fill_invite_data(data)
return Invite(**data)
@ -2139,12 +1867,8 @@ class Client:
The invite from the URL/ID.
"""
destination = self._resolve_invite(url)
rurl = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
response = yield from self.session.get(rurl, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=response))
yield from utils._verify_successful_response(response)
data = yield from response.json(encoding='utf-8')
invite_id = self._resolve_invite(url)
data = yield from self.http.get_invite(invite_id)
self._fill_invite_data(data)
return Invite(**data)
@ -2174,11 +1898,7 @@ class Client:
The list of invites that are currently active.
"""
url = '{0}/{1.id}/invites'.format(endpoints.SERVERS, server)
resp = yield from self.session.get(url, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=resp))
yield from utils._verify_successful_response(resp)
data = yield from resp.json(encoding='utf-8')
data = yield from self.http.invites_from(server.id)
result = []
for invite in data:
channel = server.get_channel(invite['channel']['id'])
@ -2210,12 +1930,8 @@ class Client:
The invite is invalid or expired.
"""
destination = self._resolve_invite(invite)
url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
response = yield from self.session.post(url, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
invite_id = self._resolve_invite(invite)
yield from self.http.accept_invite(invite_id)
@asyncio.coroutine
def delete_invite(self, invite):
@ -2241,12 +1957,8 @@ class Client:
Revoking the invite failed.
"""
destination = self._resolve_invite(invite)
url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
invite_id = self._resolve_invite(invite)
yield from self.http.delete_invite(invite_id)
# Role management
@ -2298,13 +2010,7 @@ class Client:
roles.append(role.id)
payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
data = yield from r.json()
log.debug(request_success_log.format(json=payload, response=r, data=data))
yield from self.http.patch(url, json=payload, bucket='move_role')
@asyncio.coroutine
def edit_role(self, server, role, **fields):
@ -2345,11 +2051,6 @@ class Client:
Editing the role failed.
"""
url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role)
color = fields.get('color')
if color is None:
color = fields.get('colour', role.colour)
payload = {
'name': fields.get('name', role.name),
'permissions': fields.get('permissions', role.permissions).value,
@ -2358,12 +2059,7 @@ class Client:
'mentionable': fields.get('mentionable', role.mentionable)
}
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
data = yield from r.json(encoding='utf-8')
log.debug(request_success_log.format(json=payload, response=r, data=data))
yield from self.http.edit_role(server.id, role.id, **payload)
@asyncio.coroutine
def delete_role(self, server, role):
@ -2386,24 +2082,11 @@ class Client:
Deleting the role failed.
"""
url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.delete_role(server.id, role.id)
@asyncio.coroutine
def _replace_roles(self, member, roles):
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
payload = {
'roles': roles
}
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r)
yield from r.release()
yield from self.http.replace_roles(member.id, member.server.id, roles)
@asyncio.coroutine
def add_roles(self, member, *roles):
@ -2521,12 +2204,7 @@ class Client:
is stored in cache.
"""
url = '{0}/{1.id}/roles'.format(endpoints.SERVERS, server)
r = yield from self.session.post(url, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=r))
yield from utils._verify_successful_response(r)
data = yield from r.json(encoding='utf-8')
data = yield from self.http.create_role(server.id)
role = Role(server=server, **data)
# we have to call edit because you can't pass a payload to the
@ -2581,8 +2259,6 @@ class Client:
or the target type was not :class:`Role` or :class:`Member`.
"""
url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target)
allow = Permissions.none() if allow is None else allow
deny = Permissions.none() if deny is None else deny
@ -2592,23 +2268,14 @@ class Client:
deny = deny.value
allow = allow.value
payload = {
'id': target.id,
'allow': allow,
'deny': deny
}
if isinstance(target, Member):
payload['type'] = 'member'
perm_type = 'member'
elif isinstance(target, Role):
payload['type'] = 'role'
perm_type = 'role'
else:
raise InvalidArgument('target parameter must be either discord.Member or discord.Role')
r = yield from self.session.put(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PUT', response=r))
yield from utils._verify_successful_response(r)
yield from r.release()
yield from self.http.edit_channel_permissions(channel.id, target.id, allow, deny, perm_type)
@asyncio.coroutine
def delete_channel_permissions(self, channel, target):
@ -2637,12 +2304,7 @@ class Client:
HTTPException
Deleting channel specific permissions failed.
"""
url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target)
response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.delete_channel_permissions(channel.id, target.id)
# Voice management
@ -2676,18 +2338,10 @@ class Client:
You do not have permissions to move the member.
"""
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
if getattr(channel, 'type', ChannelType.text) != ChannelType.voice:
raise InvalidArgument('The channel provided must be a voice channel.')
payload = utils.to_json({
'channel_id': channel.id
})
response = yield from self.session.patch(url, data=payload, headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=response))
yield from utils._verify_successful_response(response)
yield from response.release()
yield from self.http.move_member(member.id, member.server.id, channel.id)
@asyncio.coroutine
def join_voice_channel(self, channel):
@ -2817,10 +2471,7 @@ class Client:
HTTPException
Retrieving the information failed somehow.
"""
url = '{}/@me'.format(endpoints.APPLICATIONS)
resp = yield from self.session.get(url, headers=self.headers)
yield from utils._verify_successful_response(resp)
data = yield from resp.json()
data = yield from self.http.application_info()
return AppInfo(id=data['id'], name=data['name'],
description=data['description'], icon=data['icon'])

38
discord/gateway.py

@ -40,7 +40,7 @@ import struct
log = logging.getLogger(__name__)
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket',
__all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
'KeepAliveHandler', 'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket', 'ResumeWebSocket' ]
@ -97,36 +97,6 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
'd': int(time.time() * 1000)
}
@asyncio.coroutine
def get_gateway(token, *, loop=None):
"""Returns the gateway URL for connecting to the WebSocket.
Parameters
-----------
token : str
The discord authentication token.
loop
The event loop.
Raises
------
GatewayNotFound
When the gateway is not returned gracefully.
"""
headers = {
'authorization': token,
'content-type': 'application/json'
}
with aiohttp.ClientSession(loop=loop) as session:
resp = yield from session.get(endpoints.GATEWAY, headers=headers)
if resp.status != 200:
yield from resp.release()
raise GatewayNotFound()
data = yield from resp.json(encoding='utf-8')
return data.get('url') + '?encoding=json&v=4'
class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
"""Implements a WebSocket for Discord's gateway v4.
@ -190,11 +160,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
This is for internal use only.
"""
gateway = yield from get_gateway(client.token, loop=client.loop)
gateway = yield from client.http.get_gateway()
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
# dynamically add attributes needed
ws.token = client.token
ws.token = client.http.token
ws._connection = client.connection
ws._dispatch = client.dispatch
ws.gateway = gateway
@ -505,7 +475,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
'server_id': client.guild_id,
'user_id': client.user.id,
'session_id': client.session_id,
'token': client.token
'token': client.http.token
}
}

484
discord/http.py

@ -0,0 +1,484 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-2016 Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
import aiohttp
import asyncio
import json
import sys
import logging
import io
import inspect
import weakref
from random import randint as random_integer
log = logging.getLogger(__name__)
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound
from . import utils, __version__
@asyncio.coroutine
def json_or_text(response):
text = yield from response.text(encoding='utf-8')
if response.headers['content-type'] == 'application/json':
return json.loads(text)
return text
def _func_():
# emulate __func__ from C++
return inspect.currentframe().f_back.f_code.co_name
class HTTPClient:
"""Represents an HTTP client sending HTTP requests to the Discord API."""
BASE = 'https://discordapp.com'
API_BASE = BASE + '/api'
GATEWAY = API_BASE + '/gateway'
USERS = API_BASE + '/users'
ME = USERS + '/@me'
REGISTER = API_BASE + '/auth/register'
LOGIN = API_BASE + '/auth/login'
LOGOUT = API_BASE + '/auth/logout'
GUILDS = API_BASE + '/guilds'
CHANNELS = API_BASE + '/channels'
APPLICATIONS = API_BASE + '/oauth2/applications'
SUCCESS_LOG = '{method} {url} with {json} has received {text}'
REQUEST_LOG = '{method} {url} has returned {status}'
def __init__(self, connector=None, *, loop=None):
self.loop = asyncio.get_event_loop() if loop is None else loop
self.connector = connector
self.session = aiohttp.ClientSession(connector=connector, loop=self.loop)
self._locks = weakref.WeakValueDictionary()
self.token = None
self.bot_token = False
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
self.user_agent = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
@asyncio.coroutine
def request(self, method, url, *, bucket=None, **kwargs):
lock = self._locks.get(bucket)
if lock is None:
lock = asyncio.Lock(loop=self.loop)
if bucket is not None:
self._locks[bucket] = lock
# header creation
headers = {
'User-Agent': self.user_agent,
}
if self.token is not None:
headers['Authorization'] = 'Bot ' + self.token if self.bot_token else self.token
# some checking if it's a JSON request
if 'json' in kwargs:
headers['Content-Type'] = 'application/json'
kwargs['data'] = utils.to_json(kwargs.pop('json'))
kwargs['headers'] = headers
with (yield from lock):
for tries in range(5):
r = yield from self.session.request(method, url, **kwargs)
log.debug(self.REQUEST_LOG.format(method=method, url=url, status=r.status))
try:
# even errors have text involved in them so this is safe to call
data = yield from json_or_text(r)
# the request was successful so just return the text/json
if 300 > r.status >= 200:
log.debug(self.SUCCESS_LOG.format(method=method, url=url,
json=kwargs.get('data'), text=data))
return data
# we are being rate limited
if r.status == 429:
fmt = 'We are being rate limited. Retrying in {:.2} seconds. Handled under the bucket "{}"'
# sleep a bit
retry_after = data['retry_after'] / 1000.0
log.info(fmt.format(retry_after, bucket))
yield from asyncio.sleep(retry_after)
continue
# we've received a 502, unconditional retry
if r.status == 502 and tries <= 5:
yield from asyncio.sleep(1 + tries * 2)
continue
# the usual error cases
if r.status == 403:
raise Forbidden(r, data)
elif r.status == 404:
raise NotFound(r, data)
else:
raise HTTPException(r, data)
finally:
# clean-up just in case
yield from r.release()
def get(self, *args, **kwargs):
return self.request('GET', *args, **kwargs)
def put(self, *args, **kwargs):
return self.request('PUT', *args, **kwargs)
def patch(self, *args, **kwargs):
return self.request('PATCH', *args, **kwargs)
def delete(self, *args, **kwargs):
return self.request('DELETE', *args, **kwargs)
def post(self, *args, **kwargs):
return self.request('POST', *args, **kwargs)
# state management
@asyncio.coroutine
def close(self):
yield from self.session.close()
def recreate(self):
self.session = aiohttp.ClientSession(self.connector, loop=self.loop)
def _token(self, token, *, bot=True):
self.token = token
self.bot_token = bot
# login management
@asyncio.coroutine
def email_login(self, email, password):
payload = {
'email': email,
'password': password
}
try:
data = yield from self.post(self.LOGIN, json=payload, bucket=_func_())
except HTTPException as e:
if e.response.status == 400:
raise LoginFailure('Improper credentials have been passed.') from e
raise
self._token(data['token'], bot=False)
return data
@asyncio.coroutine
def static_login(self, token, *, bot):
old_state = (self.token, self.bot_token)
self._token(token, bot=bot)
try:
data = yield from self.get(self.ME)
except HTTPException as e:
self._token(*old_state)
if e.response.status == 401:
raise LoginFailure('Improper token has been passed.') from e
raise e
return data
def logout(self):
return self.post(self.LOGOUT, bucket=_func_())
# Message management
def start_private_message(self, user_id):
payload = {
'recipient_id': user_id
}
return self.post(self.ME + '/channels', json=payload, bucket=_func_())
def send_message(self, channel_id, content, *, guild_id=None, tts=False):
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
payload = {
'content': str(content),
'nonce': random_integer(-2**63, 2**63 - 1)
}
if tts:
payload['tts'] = True
return self.post(url, json=payload, bucket='messages:' + str(guild_id))
def send_typing(self, channel_id):
url = '{0.CHANNELS}/{1}/typing'.format(self, channel_id)
return self.post(url, bucket=_func_())
def send_file(self, channel_id, buffer, *, guild_id=None, filename=None, content=None, tts=False):
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
form = aiohttp.FormData()
if content is not None:
form.add_field('content', str(content))
form.add_field('tts', 'true' if tts else 'false')
form.add_field('file', io.BytesIO(buffer), filename=filename, content_type='application/octet-stream')
return self.post(url, data=form, bucket='messages:' + str(guild_id))
def delete_message(self, channel_id, message_id, guild_id=None):
url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
bucket = '{}:{}'.format(_func_(), guild_id)
return self.delete(url, bucket=bucket)
def delete_messages(self, channel_id, message_ids, guild_id=None):
url = '{0.CHANNELS}/{1}/messages/bulk_delete'.format(self, channel_id)
payload = {
'messages': message_ids
}
bucket = '{}:{}'.format(_func_(), guild_id)
return self.post(url, json=payload, bucket=bucket)
def edit_message(self, message_id, channel_id, content, *, guild_id=None):
url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
payload = {
'content': str(content)
}
return self.patch(url, json=payload, bucket='messages:' + str(guild_id))
def logs_from(self, channel_id, limit, before=None, after=None):
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
params = {
'limit': limit
}
if before:
params['before'] = before
if after:
params['after'] = after
return self.get(url, params=params, bucket=_func_())
# Member management
def kick(self, user_id, guild_id):
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
return self.delete(url, bucket=_func_())
def ban(self, user_id, guild_id, delete_message_days=1):
url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id)
params = {
'delete-message-days': delete_message_days
}
return self.put(url, params=params, bucket=_func_())
def unban(self, user_id, guild_id):
url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id)
return self.delete(url, bucket=_func_())
def server_voice_state(self, user_id, guild_id, *, mute=False, deafen=False):
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
payload = {
'mute': mute,
'deafen': deafen
}
return self.patch(url, json=payload, bucket='members:' + str(guild_id))
def edit_profile(self, password, username, avatar, **fields):
payload = {
'password': password,
'username': username,
'avatar': avatar
}
if 'email' in fields:
payload['email'] = fields['email']
if 'new_password' in fields:
payload['new_password'] = fields['new_password']
return self.patch(self.ME, json=payload, bucket=_func_())
def change_my_nickname(self, guild_id, nickname):
url = '{0.GUILDS}/{1}/members/@me/nick'.format(self, guild_id)
payload = {
'nick': nickname
}
bucket = '{}:{}'.format(_func_(), guild_id)
return self.patch(url, json=payload, bucket=bucket)
def change_nickname(self, guild_id, user_id, nickname):
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
payload = {
'nick': nickname
}
bucket = '{}:{}'.format(_func_(), guild_id)
return self.patch(url, json=payload, bucket=bucket)
# Channel management
def edit_channel(self, channel_id, **options):
url = '{0.CHANNELS}/{1}'.format(self, channel_id)
valid_keys = ('name', 'topic', 'bitrate', 'user_limit')
payload = {
k: v for k, v in options.items() if k in valid_keys
}
return self.patch(url, json=payload, bucket=_func_())
def create_channel(self, guild_id, name, channe_type):
url = '{0.GUILDS}/{1}/channels'.format(self, guild_id)
payload = {
'name': name,
'type': channe_type
}
return self.post(url, json=payload, bucket=_func_())
def delete_channel(self, channel_id):
url = '{0.CHANNELS}/{1}'.format(self, channel_id)
return self.delete(url, bucket=_func_())
# Server management
def leave_server(self, guild_id):
url = '{0.USERS}/@me/guilds/{1}'.format(self, guild_id)
return self.delete(url, bucket=_func_())
def delete_server(self, guild_id):
url = '{0.GUILDS}/{1}'.format(self, guild_id)
return self.delete(url, bucket=_func_())
def create_server(self, name, region, icon):
payload = {
'name': name,
'icon': icon,
'region': region
}
return self.post(self.GUILDS, json=payload, bucket=_func_())
def edit_server(self, guild_id, **fields):
valid_keys = ('name', 'region', 'icon', 'afk_timeout', 'owner_id',
'afk_channel_id', 'splash', 'verification_level')
payload = {
k: v for k, v in fields.items() if k in valid_keys
}
url = '{0.GUILDS}/{1}'.format(self, guild_id)
return self.patch(url, json=payload, bucket=_func_())
def get_bans(self, guild_id):
url = '{0.GUILDS}/{1}/bans'.format(self, guild_id)
return self.get(url, bucket=_func_())
# Invite management
def create_invite(self, channel_id, **options):
url = '{0.CHANNELS}/{1}/invites'.format(self, channel_id)
payload = {
'max_age': options.get('max_age', 0),
'max_uses': options.get('max_uses', 0),
'temporary': options.get('temporary', False),
'xkcdpass': options.get('xkcd', False)
}
return self.post(url, json=payload, bucket=_func_())
def get_invite(self, invite_id):
url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
return self.get(url, bucket=_func_())
def invites_from(self, guild_id):
url = '{0.GUILDS}/{1}/invites'.format(self, guild_id)
return self.get(url, bucket=_func_())
def accept_invite(self, invite_id):
url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
return self.post(url, bucket=_func_())
def delete_invite(self, invite_id):
url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
return self.delete(url, bucket=_func_())
# Role management
def edit_role(self, guild_id, role_id, **fields):
url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id)
valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable')
payload = {
k: v for k, v in fields.items() if k in valid_keys
}
return self.patch(url, json=payload, bucket='roles:' + str(guild_id))
def delete_role(self, guild_id, role_id):
url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id)
return self.delete(url, bucket=_func_())
def replace_roles(self, user_id, guild_id, role_ids):
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
payload = {
'roles': role_ids
}
return self.patch(url, json=payload, bucket='members:' + str(guild_id))
def create_role(self, guild_id):
url = '{0.GUILDS}/{1}/roles'.format(self, guild_id)
return self.post(url, bucket=_func_())
def edit_channel_permissions(self, channel_id, target, allow, deny, type):
url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target)
payload = {
'id': target,
'allow': allow,
'deny': deny,
'type': type
}
return self.put(url, json=payload, bucket=_func_())
def delete_channel_permissions(self, channel_id, target):
url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target)
return self.delete(url, bucket=_func_())
# Voice management
def move_member(self, user_id, guild_id, channel_id):
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
payload = {
'channel_id': channel_id
}
return self.patch(url, json=payload, bucket='members:' + str(guild_id))
# Misc
def application_info(self):
url = '{0.APPLICATIONS}/@me'.format(self)
return self.get(url, bucket=_func_())
@asyncio.coroutine
def get_gateway(self):
try:
data = yield from self.get(self.GATEWAY, bucket=_func_())
except HTTPException as e:
raise GatewayNotFound() from e
return data.get('url') + '?encoding=json&v=4'
Loading…
Cancel
Save