diff --git a/disco/api/client.py b/disco/api/client.py index d67dadf..37c7413 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -2,6 +2,8 @@ import six import json import warnings +from contextlib import contextmanager +from gevent.local import local from six.moves.urllib.parse import quote from disco.api.http import Routes, HTTPClient, to_bytes @@ -29,6 +31,15 @@ def _reason_header(value): return optional(**{'X-Audit-Log-Reason': quote(to_bytes(value)) if value else None}) +class Responses(list): + def rate_limited_duration(self): + return sum([i.rate_limited_duration for i in self]) + + @property + def rate_limited(self): + return self.rate_limited_duration() != 0 + + class APIClient(LoggingClass): """ An abstraction over a :class:`disco.api.http.HTTPClient`, which composes @@ -56,7 +67,30 @@ class APIClient(LoggingClass): super(APIClient, self).__init__() self.client = client - self.http = HTTPClient(token) + self.http = HTTPClient(token, self._after_requests) + + self._captures = local() + + def _after_requests(self, response): + if not hasattr(self._captures, 'responses'): + return + + self._captures.responses.append(response) + + @contextmanager + def capture(self): + """ + Context manager which captures all requests made, returning a special + `Responses` list, which can be used to introspect raw API responses. This + method is a low-level utility which should only be used by experienced users. + """ + responses = Responses() + self._captures.responses = responses + + try: + yield responses + finally: + delattr(self._captures, 'responses') def gateway_get(self): data = self.http(Routes.GATEWAY_GET).json() @@ -195,7 +229,10 @@ class APIClient(LoggingClass): }, headers=_reason_header(reason)) def channels_permissions_delete(self, channel, permission, reason=None): - self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission), headers=_reason_header(reason)) + self.http( + Routes.CHANNELS_PERMISSIONS_DELETE, + dict(channel=channel, permission=permission), headers=_reason_header(reason) + ) def channels_invites_list(self, channel): r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel)) @@ -247,7 +284,14 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild) - def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[], reason=None): + def guilds_channels_create(self, + guild, + name, + channel_type, + bitrate=None, + user_limit=None, + permission_overwrites=[], + reason=None): payload = { 'name': name, 'channel_type': channel_type, diff --git a/disco/api/http.py b/disco/api/http.py index e429a28..2950fcc 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -132,6 +132,13 @@ class Routes(object): WEBHOOKS_TOKEN_EXECUTE = (HTTPMethod.POST, WEBHOOKS + '/{token}') +class APIResponse(object): + def __init__(self): + self.response = None + self.exception = None + self.rate_limited_duration = 0 + + class APIException(Exception): """ Exception thrown when an HTTP-client level error occurs. Usually this will @@ -183,7 +190,7 @@ class HTTPClient(LoggingClass): BASE_URL = 'https://discordapp.com/api/v7' MAX_RETRIES = 5 - def __init__(self, token): + def __init__(self, token, after_request=None): super(HTTPClient, self).__init__() py_version = '{}.{}.{}'.format( @@ -202,6 +209,7 @@ class HTTPClient(LoggingClass): if token: self.headers['Authorization'] = 'Bot ' + token + self.after_request = after_request self.session = requests.Session() def __call__(self, route, args=None, **kwargs): @@ -251,8 +259,10 @@ class HTTPClient(LoggingClass): filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)} bucket = (route[0].value, route[1].format(**filtered)) + response = APIResponse() + # Possibly wait if we're rate limited - self.limiter.check(bucket) + response.rate_limited_duration = self.limiter.check(bucket) self.log.debug('KW: %s', kwargs) @@ -261,6 +271,10 @@ class HTTPClient(LoggingClass): self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params')) r = self.session.request(route[0].value, url, **kwargs) + if self.after_request: + response.response = r + self.after_request(response) + # Update rate limiter self.limiter.update(bucket, r) @@ -269,7 +283,8 @@ class HTTPClient(LoggingClass): return r elif r.status_code != 429 and 400 <= r.status_code < 500: self.log.warning('Request failed with code %s: %s', r.status_code, r.content) - raise APIException(r) + response.exception = APIException(r) + raise response.exception else: if r.status_code == 429: self.log.warning( diff --git a/disco/api/ratelimit.py b/disco/api/ratelimit.py index 054c8cf..ab50488 100644 --- a/disco/api/ratelimit.py +++ b/disco/api/ratelimit.py @@ -1,6 +1,7 @@ import time import gevent + from disco.util.logging import LoggingClass @@ -76,18 +77,18 @@ class RouteState(LoggingClass): """ Waits until this route is no longer under a cooldown. - Parameters - ---------- - timeout : Optional[int] - A timeout (in seconds) after which we will give up waiting - - Returns ------- - bool - False if the timeout period expired before the cooldown was finished. + float + The duration we waited for, in seconds or zero if we didn't have to + wait at all. """ - return self.event.wait(timeout) + if self.event.is_set(): + return 0 + + start = time.time() + self.event.wait() + return time.time() - start def cooldown(self): """ @@ -102,6 +103,7 @@ class RouteState(LoggingClass): gevent.sleep(delay) self.event.set() self.event = None + return delay class RateLimiter(LoggingClass): @@ -117,40 +119,37 @@ class RateLimiter(LoggingClass): def __init__(self): self.states = {} - def check(self, route, timeout=None): + def check(self, route): """ Checks whether a given route can be called. This function will return immediately if no rate-limit cooldown is being imposed for the given - route, or will wait indefinitely (unless timeout is specified) until - the route is finished being cooled down. This function should be called - before making a request to the specified route. + route, or will wait indefinitely until the route is finished being + cooled down. This function should be called before making a request to + the specified route. Parameters ---------- route : tuple(HTTPMethod, str) The route that will be checked. - timeout : Optional[int] - A timeout after which we'll give up waiting for a route's cooldown - to expire, and immediately return. Returns ------- - bool - False if the timeout period expired before the route finished cooling - down. + float + The number of seconds we had to wait for this rate limit, or zero + if no time was waited. """ - return self._check(None, timeout) and self._check(route, timeout) + return self._check(None) + self._check(route) - def _check(self, route, timeout=None): + def _check(self, route): if route in self.states: - # If we're current waiting, join the club + # If the route is being cooled off, we need to wait until its ready if self.states[route].chilled: - return self.states[route].wait(timeout) + return self.states[route].wait() if self.states[route].next_will_ratelimit: - gevent.spawn(self.states[route].cooldown).get(True, timeout) + return gevent.spawn(self.states[route].cooldown).get() - return True + return 0 def update(self, route, response): """ diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index e7ad8ec..b615aa8 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -1,3 +1,5 @@ +from __future__ import print_function + from disco.bot import Plugin from disco.util.sanitize import S @@ -10,6 +12,19 @@ class BasicPlugin(Plugin): # channel = event.guild.create_channel('audit-log-test', 'text', reason='TEST CREATE') # channel.delete(reason='TEST AUDIT 2') + @Plugin.command('ratelimitme') + def on_ratelimitme(self, event): + msg = event.msg.reply('Hi!') + + with self.client.api.capture() as requests: + for i in range(6): + msg.edit('Hi {}!'.format(i)) + + print('Rate limited {} for {}'.format( + requests.rate_limited, + requests.rate_limited_duration(), + )) + @Plugin.command('ban', ' ') def on_ban(self, event, user, reason): event.guild.create_ban(user, reason=reason + u'\U0001F4BF') diff --git a/tests/api/client.py b/tests/api/client.py new file mode 100644 index 0000000..cbc5c83 --- /dev/null +++ b/tests/api/client.py @@ -0,0 +1,18 @@ +from disco.api.client import Responses +from disco.api.http import APIResponse + + +def test_responses_list(): + r = Responses() + r.append(APIResponse()) + r.append(APIResponse()) + + assert not r.rate_limited + assert r.rate_limited_duration() == 0 + + res = APIResponse() + res.rate_limited_duration = 5.5 + r.append(res) + + assert r.rate_limited + assert r.rate_limited_duration() == 5.5