Browse Source

Capture Raw API Responses (#46)

* Add support for capturing and viewing API responses

* Just subclass list, naming changes, tests
pull/47/head
Andrei Zbikowski 8 years ago
committed by GitHub
parent
commit
9e9d6bb1b1
  1. 50
      disco/api/client.py
  2. 21
      disco/api/http.py
  3. 49
      disco/api/ratelimit.py
  4. 15
      examples/basic_plugin.py
  5. 18
      tests/api/client.py

50
disco/api/client.py

@ -2,6 +2,8 @@ import six
import json import json
import warnings import warnings
from contextlib import contextmanager
from gevent.local import local
from six.moves.urllib.parse import quote from six.moves.urllib.parse import quote
from disco.api.http import Routes, HTTPClient, to_bytes from disco.api.http import Routes, HTTPClient, to_bytes
@ -29,6 +31,15 @@ def _reason_header(value):
return optional(**{'X-Audit-Log-Reason': quote(to_bytes(value)) if value else None}) return optional(**{'X-Audit-Log-Reason': quote(to_bytes(value)) if value else None})
class Responses(list):
def rate_limited_duration(self):
return sum([i.rate_limited_duration for i in self])
@property
def rate_limited(self):
return self.rate_limited_duration() != 0
class APIClient(LoggingClass): class APIClient(LoggingClass):
""" """
An abstraction over a :class:`disco.api.http.HTTPClient`, which composes An abstraction over a :class:`disco.api.http.HTTPClient`, which composes
@ -56,7 +67,30 @@ class APIClient(LoggingClass):
super(APIClient, self).__init__() super(APIClient, self).__init__()
self.client = client self.client = client
self.http = HTTPClient(token) self.http = HTTPClient(token, self._after_requests)
self._captures = local()
def _after_requests(self, response):
if not hasattr(self._captures, 'responses'):
return
self._captures.responses.append(response)
@contextmanager
def capture(self):
"""
Context manager which captures all requests made, returning a special
`Responses` list, which can be used to introspect raw API responses. This
method is a low-level utility which should only be used by experienced users.
"""
responses = Responses()
self._captures.responses = responses
try:
yield responses
finally:
delattr(self._captures, 'responses')
def gateway_get(self): def gateway_get(self):
data = self.http(Routes.GATEWAY_GET).json() data = self.http(Routes.GATEWAY_GET).json()
@ -195,7 +229,10 @@ class APIClient(LoggingClass):
}, headers=_reason_header(reason)) }, headers=_reason_header(reason))
def channels_permissions_delete(self, channel, permission, reason=None): def channels_permissions_delete(self, channel, permission, reason=None):
self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission), headers=_reason_header(reason)) self.http(
Routes.CHANNELS_PERMISSIONS_DELETE,
dict(channel=channel, permission=permission), headers=_reason_header(reason)
)
def channels_invites_list(self, channel): def channels_invites_list(self, channel):
r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel)) r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel))
@ -247,7 +284,14 @@ class APIClient(LoggingClass):
r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild) return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[], reason=None): def guilds_channels_create(self,
guild,
name,
channel_type,
bitrate=None,
user_limit=None,
permission_overwrites=[],
reason=None):
payload = { payload = {
'name': name, 'name': name,
'channel_type': channel_type, 'channel_type': channel_type,

21
disco/api/http.py

@ -132,6 +132,13 @@ class Routes(object):
WEBHOOKS_TOKEN_EXECUTE = (HTTPMethod.POST, WEBHOOKS + '/{token}') WEBHOOKS_TOKEN_EXECUTE = (HTTPMethod.POST, WEBHOOKS + '/{token}')
class APIResponse(object):
def __init__(self):
self.response = None
self.exception = None
self.rate_limited_duration = 0
class APIException(Exception): class APIException(Exception):
""" """
Exception thrown when an HTTP-client level error occurs. Usually this will Exception thrown when an HTTP-client level error occurs. Usually this will
@ -183,7 +190,7 @@ class HTTPClient(LoggingClass):
BASE_URL = 'https://discordapp.com/api/v7' BASE_URL = 'https://discordapp.com/api/v7'
MAX_RETRIES = 5 MAX_RETRIES = 5
def __init__(self, token): def __init__(self, token, after_request=None):
super(HTTPClient, self).__init__() super(HTTPClient, self).__init__()
py_version = '{}.{}.{}'.format( py_version = '{}.{}.{}'.format(
@ -202,6 +209,7 @@ class HTTPClient(LoggingClass):
if token: if token:
self.headers['Authorization'] = 'Bot ' + token self.headers['Authorization'] = 'Bot ' + token
self.after_request = after_request
self.session = requests.Session() self.session = requests.Session()
def __call__(self, route, args=None, **kwargs): def __call__(self, route, args=None, **kwargs):
@ -251,8 +259,10 @@ class HTTPClient(LoggingClass):
filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)} filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)}
bucket = (route[0].value, route[1].format(**filtered)) bucket = (route[0].value, route[1].format(**filtered))
response = APIResponse()
# Possibly wait if we're rate limited # Possibly wait if we're rate limited
self.limiter.check(bucket) response.rate_limited_duration = self.limiter.check(bucket)
self.log.debug('KW: %s', kwargs) self.log.debug('KW: %s', kwargs)
@ -261,6 +271,10 @@ class HTTPClient(LoggingClass):
self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params')) self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params'))
r = self.session.request(route[0].value, url, **kwargs) r = self.session.request(route[0].value, url, **kwargs)
if self.after_request:
response.response = r
self.after_request(response)
# Update rate limiter # Update rate limiter
self.limiter.update(bucket, r) self.limiter.update(bucket, r)
@ -269,7 +283,8 @@ class HTTPClient(LoggingClass):
return r return r
elif r.status_code != 429 and 400 <= r.status_code < 500: elif r.status_code != 429 and 400 <= r.status_code < 500:
self.log.warning('Request failed with code %s: %s', r.status_code, r.content) self.log.warning('Request failed with code %s: %s', r.status_code, r.content)
raise APIException(r) response.exception = APIException(r)
raise response.exception
else: else:
if r.status_code == 429: if r.status_code == 429:
self.log.warning( self.log.warning(

49
disco/api/ratelimit.py

@ -1,6 +1,7 @@
import time import time
import gevent import gevent
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
@ -76,18 +77,18 @@ class RouteState(LoggingClass):
""" """
Waits until this route is no longer under a cooldown. Waits until this route is no longer under a cooldown.
Parameters
----------
timeout : Optional[int]
A timeout (in seconds) after which we will give up waiting
Returns Returns
------- -------
bool float
False if the timeout period expired before the cooldown was finished. The duration we waited for, in seconds or zero if we didn't have to
wait at all.
""" """
return self.event.wait(timeout) if self.event.is_set():
return 0
start = time.time()
self.event.wait()
return time.time() - start
def cooldown(self): def cooldown(self):
""" """
@ -102,6 +103,7 @@ class RouteState(LoggingClass):
gevent.sleep(delay) gevent.sleep(delay)
self.event.set() self.event.set()
self.event = None self.event = None
return delay
class RateLimiter(LoggingClass): class RateLimiter(LoggingClass):
@ -117,40 +119,37 @@ class RateLimiter(LoggingClass):
def __init__(self): def __init__(self):
self.states = {} self.states = {}
def check(self, route, timeout=None): def check(self, route):
""" """
Checks whether a given route can be called. This function will return Checks whether a given route can be called. This function will return
immediately if no rate-limit cooldown is being imposed for the given immediately if no rate-limit cooldown is being imposed for the given
route, or will wait indefinitely (unless timeout is specified) until route, or will wait indefinitely until the route is finished being
the route is finished being cooled down. This function should be called cooled down. This function should be called before making a request to
before making a request to the specified route. the specified route.
Parameters Parameters
---------- ----------
route : tuple(HTTPMethod, str) route : tuple(HTTPMethod, str)
The route that will be checked. The route that will be checked.
timeout : Optional[int]
A timeout after which we'll give up waiting for a route's cooldown
to expire, and immediately return.
Returns Returns
------- -------
bool float
False if the timeout period expired before the route finished cooling The number of seconds we had to wait for this rate limit, or zero
down. if no time was waited.
""" """
return self._check(None, timeout) and self._check(route, timeout) return self._check(None) + self._check(route)
def _check(self, route, timeout=None): def _check(self, route):
if route in self.states: if route in self.states:
# If we're current waiting, join the club # If the route is being cooled off, we need to wait until its ready
if self.states[route].chilled: if self.states[route].chilled:
return self.states[route].wait(timeout) return self.states[route].wait()
if self.states[route].next_will_ratelimit: if self.states[route].next_will_ratelimit:
gevent.spawn(self.states[route].cooldown).get(True, timeout) return gevent.spawn(self.states[route].cooldown).get()
return True return 0
def update(self, route, response): def update(self, route, response):
""" """

15
examples/basic_plugin.py

@ -1,3 +1,5 @@
from __future__ import print_function
from disco.bot import Plugin from disco.bot import Plugin
from disco.util.sanitize import S from disco.util.sanitize import S
@ -10,6 +12,19 @@ class BasicPlugin(Plugin):
# channel = event.guild.create_channel('audit-log-test', 'text', reason='TEST CREATE') # channel = event.guild.create_channel('audit-log-test', 'text', reason='TEST CREATE')
# channel.delete(reason='TEST AUDIT 2') # channel.delete(reason='TEST AUDIT 2')
@Plugin.command('ratelimitme')
def on_ratelimitme(self, event):
msg = event.msg.reply('Hi!')
with self.client.api.capture() as requests:
for i in range(6):
msg.edit('Hi {}!'.format(i))
print('Rate limited {} for {}'.format(
requests.rate_limited,
requests.rate_limited_duration(),
))
@Plugin.command('ban', '<user:snowflake> <reason:str...>') @Plugin.command('ban', '<user:snowflake> <reason:str...>')
def on_ban(self, event, user, reason): def on_ban(self, event, user, reason):
event.guild.create_ban(user, reason=reason + u'\U0001F4BF') event.guild.create_ban(user, reason=reason + u'\U0001F4BF')

18
tests/api/client.py

@ -0,0 +1,18 @@
from disco.api.client import Responses
from disco.api.http import APIResponse
def test_responses_list():
r = Responses()
r.append(APIResponse())
r.append(APIResponse())
assert not r.rate_limited
assert r.rate_limited_duration() == 0
res = APIResponse()
res.rate_limited_duration = 5.5
r.append(res)
assert r.rate_limited
assert r.rate_limited_duration() == 5.5
Loading…
Cancel
Save