Browse Source

Capture Raw API Responses (#46)

* Add support for capturing and viewing API responses

* Just subclass list, naming changes, tests
pull/47/head
Andrei Zbikowski 8 years ago
committed by GitHub
parent
commit
9e9d6bb1b1
  1. 50
      disco/api/client.py
  2. 21
      disco/api/http.py
  3. 49
      disco/api/ratelimit.py
  4. 15
      examples/basic_plugin.py
  5. 18
      tests/api/client.py

50
disco/api/client.py

@ -2,6 +2,8 @@ import six
import json
import warnings
from contextlib import contextmanager
from gevent.local import local
from six.moves.urllib.parse import quote
from disco.api.http import Routes, HTTPClient, to_bytes
@ -29,6 +31,15 @@ def _reason_header(value):
return optional(**{'X-Audit-Log-Reason': quote(to_bytes(value)) if value else None})
class Responses(list):
def rate_limited_duration(self):
return sum([i.rate_limited_duration for i in self])
@property
def rate_limited(self):
return self.rate_limited_duration() != 0
class APIClient(LoggingClass):
"""
An abstraction over a :class:`disco.api.http.HTTPClient`, which composes
@ -56,7 +67,30 @@ class APIClient(LoggingClass):
super(APIClient, self).__init__()
self.client = client
self.http = HTTPClient(token)
self.http = HTTPClient(token, self._after_requests)
self._captures = local()
def _after_requests(self, response):
if not hasattr(self._captures, 'responses'):
return
self._captures.responses.append(response)
@contextmanager
def capture(self):
"""
Context manager which captures all requests made, returning a special
`Responses` list, which can be used to introspect raw API responses. This
method is a low-level utility which should only be used by experienced users.
"""
responses = Responses()
self._captures.responses = responses
try:
yield responses
finally:
delattr(self._captures, 'responses')
def gateway_get(self):
data = self.http(Routes.GATEWAY_GET).json()
@ -195,7 +229,10 @@ class APIClient(LoggingClass):
}, headers=_reason_header(reason))
def channels_permissions_delete(self, channel, permission, reason=None):
self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission), headers=_reason_header(reason))
self.http(
Routes.CHANNELS_PERMISSIONS_DELETE,
dict(channel=channel, permission=permission), headers=_reason_header(reason)
)
def channels_invites_list(self, channel):
r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel))
@ -247,7 +284,14 @@ class APIClient(LoggingClass):
r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[], reason=None):
def guilds_channels_create(self,
guild,
name,
channel_type,
bitrate=None,
user_limit=None,
permission_overwrites=[],
reason=None):
payload = {
'name': name,
'channel_type': channel_type,

21
disco/api/http.py

@ -132,6 +132,13 @@ class Routes(object):
WEBHOOKS_TOKEN_EXECUTE = (HTTPMethod.POST, WEBHOOKS + '/{token}')
class APIResponse(object):
def __init__(self):
self.response = None
self.exception = None
self.rate_limited_duration = 0
class APIException(Exception):
"""
Exception thrown when an HTTP-client level error occurs. Usually this will
@ -183,7 +190,7 @@ class HTTPClient(LoggingClass):
BASE_URL = 'https://discordapp.com/api/v7'
MAX_RETRIES = 5
def __init__(self, token):
def __init__(self, token, after_request=None):
super(HTTPClient, self).__init__()
py_version = '{}.{}.{}'.format(
@ -202,6 +209,7 @@ class HTTPClient(LoggingClass):
if token:
self.headers['Authorization'] = 'Bot ' + token
self.after_request = after_request
self.session = requests.Session()
def __call__(self, route, args=None, **kwargs):
@ -251,8 +259,10 @@ class HTTPClient(LoggingClass):
filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)}
bucket = (route[0].value, route[1].format(**filtered))
response = APIResponse()
# Possibly wait if we're rate limited
self.limiter.check(bucket)
response.rate_limited_duration = self.limiter.check(bucket)
self.log.debug('KW: %s', kwargs)
@ -261,6 +271,10 @@ class HTTPClient(LoggingClass):
self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params'))
r = self.session.request(route[0].value, url, **kwargs)
if self.after_request:
response.response = r
self.after_request(response)
# Update rate limiter
self.limiter.update(bucket, r)
@ -269,7 +283,8 @@ class HTTPClient(LoggingClass):
return r
elif r.status_code != 429 and 400 <= r.status_code < 500:
self.log.warning('Request failed with code %s: %s', r.status_code, r.content)
raise APIException(r)
response.exception = APIException(r)
raise response.exception
else:
if r.status_code == 429:
self.log.warning(

49
disco/api/ratelimit.py

@ -1,6 +1,7 @@
import time
import gevent
from disco.util.logging import LoggingClass
@ -76,18 +77,18 @@ class RouteState(LoggingClass):
"""
Waits until this route is no longer under a cooldown.
Parameters
----------
timeout : Optional[int]
A timeout (in seconds) after which we will give up waiting
Returns
-------
bool
False if the timeout period expired before the cooldown was finished.
float
The duration we waited for, in seconds or zero if we didn't have to
wait at all.
"""
return self.event.wait(timeout)
if self.event.is_set():
return 0
start = time.time()
self.event.wait()
return time.time() - start
def cooldown(self):
"""
@ -102,6 +103,7 @@ class RouteState(LoggingClass):
gevent.sleep(delay)
self.event.set()
self.event = None
return delay
class RateLimiter(LoggingClass):
@ -117,40 +119,37 @@ class RateLimiter(LoggingClass):
def __init__(self):
self.states = {}
def check(self, route, timeout=None):
def check(self, route):
"""
Checks whether a given route can be called. This function will return
immediately if no rate-limit cooldown is being imposed for the given
route, or will wait indefinitely (unless timeout is specified) until
the route is finished being cooled down. This function should be called
before making a request to the specified route.
route, or will wait indefinitely until the route is finished being
cooled down. This function should be called before making a request to
the specified route.
Parameters
----------
route : tuple(HTTPMethod, str)
The route that will be checked.
timeout : Optional[int]
A timeout after which we'll give up waiting for a route's cooldown
to expire, and immediately return.
Returns
-------
bool
False if the timeout period expired before the route finished cooling
down.
float
The number of seconds we had to wait for this rate limit, or zero
if no time was waited.
"""
return self._check(None, timeout) and self._check(route, timeout)
return self._check(None) + self._check(route)
def _check(self, route, timeout=None):
def _check(self, route):
if route in self.states:
# If we're current waiting, join the club
# If the route is being cooled off, we need to wait until its ready
if self.states[route].chilled:
return self.states[route].wait(timeout)
return self.states[route].wait()
if self.states[route].next_will_ratelimit:
gevent.spawn(self.states[route].cooldown).get(True, timeout)
return gevent.spawn(self.states[route].cooldown).get()
return True
return 0
def update(self, route, response):
"""

15
examples/basic_plugin.py

@ -1,3 +1,5 @@
from __future__ import print_function
from disco.bot import Plugin
from disco.util.sanitize import S
@ -10,6 +12,19 @@ class BasicPlugin(Plugin):
# channel = event.guild.create_channel('audit-log-test', 'text', reason='TEST CREATE')
# channel.delete(reason='TEST AUDIT 2')
@Plugin.command('ratelimitme')
def on_ratelimitme(self, event):
msg = event.msg.reply('Hi!')
with self.client.api.capture() as requests:
for i in range(6):
msg.edit('Hi {}!'.format(i))
print('Rate limited {} for {}'.format(
requests.rate_limited,
requests.rate_limited_duration(),
))
@Plugin.command('ban', '<user:snowflake> <reason:str...>')
def on_ban(self, event, user, reason):
event.guild.create_ban(user, reason=reason + u'\U0001F4BF')

18
tests/api/client.py

@ -0,0 +1,18 @@
from disco.api.client import Responses
from disco.api.http import APIResponse
def test_responses_list():
r = Responses()
r.append(APIResponse())
r.append(APIResponse())
assert not r.rate_limited
assert r.rate_limited_duration() == 0
res = APIResponse()
res.rate_limited_duration = 5.5
r.append(res)
assert r.rate_limited
assert r.rate_limited_duration() == 5.5
Loading…
Cancel
Save