Browse Source

Switch to using ClientSession objects for aiohttp v0.21

pull/108/head
Rapptz 9 years ago
parent
commit
ff14fa0fe8
  1. 84
      discord/client.py

84
discord/client.py

@ -124,7 +124,7 @@ class Client:
self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop) self.connection = ConnectionState(self.dispatch, max_messages, loop=self.loop)
# Blame React for this # Blame Jake for this
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
self.headers = { self.headers = {
@ -132,6 +132,8 @@ class Client:
'user-agent': user_agent.format(library_version, sys.version_info, aiohttp.__version__) 'user-agent': user_agent.format(library_version, sys.version_info, aiohttp.__version__)
} }
self.session = aiohttp.ClientSession(loop=self.loop)
self._closed = asyncio.Event(loop=self.loop) self._closed = asyncio.Event(loop=self.loop)
self._is_logged_in = asyncio.Event(loop=self.loop) self._is_logged_in = asyncio.Event(loop=self.loop)
self._is_ready = asyncio.Event(loop=self.loop) self._is_ready = asyncio.Event(loop=self.loop)
@ -269,7 +271,7 @@ class Client:
@asyncio.coroutine @asyncio.coroutine
def _get_gateway(self): def _get_gateway(self):
resp = yield from aiohttp.get(endpoints.GATEWAY, headers=self.headers, loop=self.loop) resp = yield from self.session.get(endpoints.GATEWAY, headers=self.headers)
if resp.status != 200: if resp.status != 200:
raise GatewayNotFound() raise GatewayNotFound()
data = yield from resp.json() data = yield from resp.json()
@ -491,7 +493,7 @@ class Client:
} }
data = utils.to_json(payload) data = utils.to_json(payload)
resp = yield from aiohttp.post(endpoints.LOGIN, data=data, headers=self.headers, loop=self.loop) resp = yield from self.session.post(endpoints.LOGIN, data=data, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=resp)) log.debug(request_logging_format.format(method='POST', response=resp))
if resp.status != 200: if resp.status != 200:
yield from resp.release() yield from resp.release()
@ -518,7 +520,7 @@ class Client:
"""|coro| """|coro|
Logs out of Discord and closes all connections.""" Logs out of Discord and closes all connections."""
response = yield from aiohttp.post(endpoints.LOGOUT, headers=self.headers, loop=self.loop) response = yield from self.session.post(endpoints.LOGOUT, headers=self.headers)
yield from response.release() yield from response.release()
yield from self.close() yield from self.close()
self._is_logged_in.clear() self._is_logged_in.clear()
@ -528,14 +530,9 @@ class Client:
def connect(self): def connect(self):
"""|coro| """|coro|
Creates a websocket connection and connects to the websocket listen Creates a websocket connection and lets the websocket listen
to messages from discord. to messages from discord.
This function is implemented using a while loop in the background.
If you need to run this event listening in another thread then
you should run it in an executor or schedule the coroutine to
be executed later using ``loop.create_task``.
Raises Raises
------- -------
ClientException ClientException
@ -578,6 +575,8 @@ class Client:
if self.ws.open: if self.ws.open:
yield from self.ws.close() yield from self.ws.close()
yield from self.session.close()
self.keep_alive.cancel() self.keep_alive.cancel()
self._closed.set() self._closed.set()
self._is_ready.clear() self._is_ready.clear()
@ -901,7 +900,7 @@ class Client:
} }
url = '{}/channels'.format(endpoints.ME) url = '{}/channels'.format(endpoints.ME)
r = yield from aiohttp.post(url, data=utils.to_json(payload), headers=self.headers, loop=self.loop) r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=r)) log.debug(request_logging_format.format(method='POST', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
data = yield from r.json() data = yield from r.json()
@ -912,7 +911,7 @@ class Client:
@asyncio.coroutine @asyncio.coroutine
def _rate_limit_helper(self, name, method, url, data): def _rate_limit_helper(self, name, method, url, data):
resp = yield from aiohttp.request(method, url, data=data, headers=self.headers, loop=self.loop) resp = yield from self.session.request(method, url, data=data, headers=self.headers)
tmp = request_logging_format.format(method=method, response=resp) tmp = request_logging_format.format(method=method, response=resp)
log_fmt = 'In {}, {}'.format(name, tmp) log_fmt = 'In {}, {}'.format(name, tmp)
log.debug(log_fmt) log.debug(log_fmt)
@ -1012,7 +1011,7 @@ class Client:
url = '{base}/{id}/typing'.format(base=endpoints.CHANNELS, id=channel_id) url = '{base}/{id}/typing'.format(base=endpoints.CHANNELS, id=channel_id)
response = yield from aiohttp.post(url, headers=self.headers, loop=self.loop) response = yield from self.session.post(url, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response)) log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1066,17 +1065,18 @@ class Client:
# we don't want the content-type json in this request # we don't want the content-type json in this request
headers = { headers = {
'authorization': self.token 'authorization': self.token,
'user-agent': user_agent.format(library_version, sys.version_info, aiohttp.__version__)
} }
try: try:
# attempt to open the file and send the request # attempt to open the file and send the request
with open(fp, 'rb') as f: with open(fp, 'rb') as f:
files.add_field('file', f, filename=filename, content_type='application/octet-stream') files.add_field('file', f, filename=filename, content_type='application/octet-stream')
response = yield from aiohttp.post(url, data=files, headers=headers, loop=self.loop) response = yield from self.session.post(url, data=files, headers=headers)
except TypeError: except TypeError:
files.add_field('file', fp, filename=filename, content_type='application/octet-stream') files.add_field('file', fp, filename=filename, content_type='application/octet-stream')
response = yield from aiohttp.post(url, data=files, headers=headers, loop=self.loop) response = yield from self.session.post(url, data=files, headers=headers)
log.debug(request_logging_format.format(method='POST', response=response)) log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
@ -1110,7 +1110,7 @@ class Client:
""" """
url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, message.channel.id, message.id) url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, message.channel.id, message.id)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1215,7 +1215,7 @@ class Client:
if after: if after:
params['after'] = after.id params['after'] = after.id
response = yield from aiohttp.get(url, params=params, headers=self.headers, loop=self.loop) response = yield from self.session.get(url, params=params, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=response)) log.debug(request_logging_format.format(method='GET', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
messages = yield from response.json() messages = yield from response.json()
@ -1311,7 +1311,7 @@ class Client:
""" """
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1349,7 +1349,7 @@ class Client:
} }
url = '{0}/{1.server.id}/bans/{1.id}'.format(endpoints.SERVERS, member) url = '{0}/{1.server.id}/bans/{1.id}'.format(endpoints.SERVERS, member)
response = yield from aiohttp.put(url, params=params, headers=self.headers, loop=self.loop) response = yield from self.session.put(url, params=params, headers=self.headers)
log.debug(request_logging_format.format(method='PUT', response=response)) log.debug(request_logging_format.format(method='PUT', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1376,7 +1376,7 @@ class Client:
""" """
url = '{0}/{1.id}/bans/{2.id}'.format(endpoints.SERVERS, server, user) url = '{0}/{1.id}/bans/{2.id}'.format(endpoints.SERVERS, server, user)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1416,7 +1416,7 @@ class Client:
'deaf': deafen 'deaf': deafen
} }
response = yield from aiohttp.patch(url, headers=self.headers, data=utils.to_json(payload), loop=self.loop) response = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=response)) log.debug(request_logging_format.format(method='PATCH', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1480,7 +1480,7 @@ class Client:
'avatar': avatar 'avatar': avatar
} }
r = yield from aiohttp.patch(endpoints.ME, headers=self.headers, data=utils.to_json(payload), loop=self.loop) r = yield from self.session.patch(endpoints.ME, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r)) log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
@ -1580,7 +1580,7 @@ class Client:
'position': options.get('position', channel.position) 'position': options.get('position', channel.position)
} }
r = yield from aiohttp.patch(url, headers=self.headers, data=utils.to_json(payload), loop=self.loop) r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r)) log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
@ -1629,7 +1629,7 @@ class Client:
} }
url = '{0}/{1.id}/channels'.format(endpoints.SERVERS, server) url = '{0}/{1.id}/channels'.format(endpoints.SERVERS, server)
response = yield from aiohttp.post(url, headers=self.headers, data=utils.to_json(payload), loop=self.loop) response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response)) log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
@ -1663,7 +1663,7 @@ class Client:
""" """
url = '{}/{}'.format(endpoints.CHANNELS, channel.id) url = '{}/{}'.format(endpoints.CHANNELS, channel.id)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1692,7 +1692,7 @@ class Client:
""" """
url = '{0}/{1.id}'.format(endpoints.SERVERS, server) url = '{0}/{1.id}'.format(endpoints.SERVERS, server)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -1741,7 +1741,7 @@ class Client:
'region': region 'region': region
} }
r = yield from aiohttp.post(endpoints.SERVERS, data=utils.to_json(payload), headers=self.headers, loop=self.loop) r = yield from self.session.post(endpoints.SERVERS, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=r)) log.debug(request_logging_format.format(method='POST', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
data = yield from r.json() data = yield from r.json()
@ -1821,7 +1821,7 @@ class Client:
payload['owner_id'] = fields['owner'].id payload['owner_id'] = fields['owner'].id
url = '{0}/{1.id}'.format(endpoints.SERVERS, server) url = '{0}/{1.id}'.format(endpoints.SERVERS, server)
r = yield from aiohttp.patch(url, headers=self.headers, data=utils.to_json(payload), loop=self.loop) r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r)) log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
yield from r.release() yield from r.release()
@ -1854,7 +1854,7 @@ class Client:
""" """
url = '{0}/{1.id}/bans'.format(endpoints.SERVERS, server) url = '{0}/{1.id}/bans'.format(endpoints.SERVERS, server)
resp = yield from aiohttp.get(url, headers=self.headers, loop=self.loop) resp = yield from self.session.get(url, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=resp)) log.debug(request_logging_format.format(method='GET', response=resp))
yield from utils._verify_successful_response(resp) yield from utils._verify_successful_response(resp)
data = yield from resp.json() data = yield from resp.json()
@ -1917,7 +1917,7 @@ class Client:
} }
url = '{0}/{1.id}/invites'.format(endpoints.CHANNELS, destination) url = '{0}/{1.id}/invites'.format(endpoints.CHANNELS, destination)
response = yield from aiohttp.post(url, headers=self.headers, data=utils.to_json(payload), loop=self.loop) response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response)) log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
@ -1958,7 +1958,7 @@ class Client:
destination = self._resolve_invite(url) destination = self._resolve_invite(url)
rurl = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) rurl = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
response = yield from aiohttp.get(rurl, headers=self.headers, loop=self.loop) response = yield from self.session.get(rurl, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=response)) log.debug(request_logging_format.format(method='GET', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
data = yield from response.json() data = yield from response.json()
@ -1992,7 +1992,7 @@ class Client:
""" """
url = '{0}/{1.id}/invites'.format(endpoints.SERVERS, server) url = '{0}/{1.id}/invites'.format(endpoints.SERVERS, server)
resp = yield from aiohttp.get(url, headers=self.headers, loop=self.loop) resp = yield from self.session.get(url, headers=self.headers)
log.debug(request_logging_format.format(method='GET', response=resp)) log.debug(request_logging_format.format(method='GET', response=resp))
yield from utils._verify_successful_response(resp) yield from utils._verify_successful_response(resp)
data = yield from resp.json() data = yield from resp.json()
@ -2029,7 +2029,7 @@ class Client:
destination = self._resolve_invite(invite) destination = self._resolve_invite(invite)
url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
response = yield from aiohttp.post(url, headers=self.headers, loop=self.loop) response = yield from self.session.post(url, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=response)) log.debug(request_logging_format.format(method='POST', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -2060,7 +2060,7 @@ class Client:
destination = self._resolve_invite(invite) destination = self._resolve_invite(invite)
url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -2122,7 +2122,7 @@ class Client:
'hoist': fields.get('hoist', role.hoist) 'hoist': fields.get('hoist', role.hoist)
} }
r = yield from aiohttp.patch(url, data=utils.to_json(payload), headers=self.headers, loop=self.loop) r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r)) log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
@ -2153,7 +2153,7 @@ class Client:
""" """
url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role) url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -2166,7 +2166,7 @@ class Client:
'roles': roles 'roles': roles
} }
r = yield from aiohttp.patch(url, headers=self.headers, data=utils.to_json(payload), loop=self.loop) r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=r)) log.debug(request_logging_format.format(method='PATCH', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
yield from r.release() yield from r.release()
@ -2288,7 +2288,7 @@ class Client:
""" """
url = '{0}/{1.id}/roles'.format(endpoints.SERVERS, server) url = '{0}/{1.id}/roles'.format(endpoints.SERVERS, server)
r = yield from aiohttp.post(url, headers=self.headers, loop=self.loop) r = yield from self.session.post(url, headers=self.headers)
log.debug(request_logging_format.format(method='POST', response=r)) log.debug(request_logging_format.format(method='POST', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
@ -2372,7 +2372,7 @@ class Client:
else: else:
raise InvalidArgument('target parameter must be either discord.Member or discord.Role') raise InvalidArgument('target parameter must be either discord.Member or discord.Role')
r = yield from aiohttp.put(url, data=utils.to_json(payload), headers=self.headers, loop=self.loop) r = yield from self.session.put(url, data=utils.to_json(payload), headers=self.headers)
log.debug(request_logging_format.format(method='PUT', response=r)) log.debug(request_logging_format.format(method='PUT', response=r))
yield from utils._verify_successful_response(r) yield from utils._verify_successful_response(r)
yield from r.release() yield from r.release()
@ -2406,7 +2406,7 @@ class Client:
""" """
url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target) url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target)
response = yield from aiohttp.delete(url, headers=self.headers, loop=self.loop) response = yield from self.session.delete(url, headers=self.headers)
log.debug(request_logging_format.format(method='DELETE', response=response)) log.debug(request_logging_format.format(method='DELETE', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()
@ -2451,7 +2451,7 @@ class Client:
payload = utils.to_json({ payload = utils.to_json({
'channel_id': channel.id 'channel_id': channel.id
}) })
response = yield from aiohttp.patch(url, data=payload, headers=self.headers, loop=self.loop) response = yield from self.session.patch(url, data=payload, headers=self.headers)
log.debug(request_logging_format.format(method='PATCH', response=response)) log.debug(request_logging_format.format(method='PATCH', response=response))
yield from utils._verify_successful_response(response) yield from utils._verify_successful_response(response)
yield from response.release() yield from response.release()

Loading…
Cancel
Save