From 9e9d6bb1b1ba8ec1806e1f869cd831449c683648 Mon Sep 17 00:00:00 2001
From: Andrei Zbikowski <b1naryth1ef@users.noreply.github.com>
Date: Wed, 19 Jul 2017 21:02:47 -0700
Subject: [PATCH] Capture Raw API Responses (#46)

* Add support for capturing and viewing API responses

* Just subclass list, naming changes, tests
---
 disco/api/client.py      | 50 +++++++++++++++++++++++++++++++++++++---
 disco/api/http.py        | 21 ++++++++++++++---
 disco/api/ratelimit.py   | 49 +++++++++++++++++++--------------------
 examples/basic_plugin.py | 15 ++++++++++++
 tests/api/client.py      | 18 +++++++++++++++
 5 files changed, 122 insertions(+), 31 deletions(-)
 create mode 100644 tests/api/client.py

diff --git a/disco/api/client.py b/disco/api/client.py
index d67dadf..37c7413 100644
--- a/disco/api/client.py
+++ b/disco/api/client.py
@@ -2,6 +2,8 @@ import six
 import json
 import warnings
 
+from contextlib import contextmanager
+from gevent.local import local
 from six.moves.urllib.parse import quote
 
 from disco.api.http import Routes, HTTPClient, to_bytes
@@ -29,6 +31,15 @@ def _reason_header(value):
     return optional(**{'X-Audit-Log-Reason': quote(to_bytes(value)) if value else None})
 
 
+class Responses(list):
+    def rate_limited_duration(self):
+        return sum([i.rate_limited_duration for i in self])
+
+    @property
+    def rate_limited(self):
+        return self.rate_limited_duration() != 0
+
+
 class APIClient(LoggingClass):
     """
     An abstraction over a :class:`disco.api.http.HTTPClient`, which composes
@@ -56,7 +67,30 @@ class APIClient(LoggingClass):
         super(APIClient, self).__init__()
 
         self.client = client
-        self.http = HTTPClient(token)
+        self.http = HTTPClient(token, self._after_requests)
+
+        self._captures = local()
+
+    def _after_requests(self, response):
+        if not hasattr(self._captures, 'responses'):
+            return
+
+        self._captures.responses.append(response)
+
+    @contextmanager
+    def capture(self):
+        """
+        Context manager which captures all requests made, returning a special
+        `Responses` list, which can be used to introspect raw API responses. This
+        method is a low-level utility which should only be used by experienced users.
+        """
+        responses = Responses()
+        self._captures.responses = responses
+
+        try:
+            yield responses
+        finally:
+            delattr(self._captures, 'responses')
 
     def gateway_get(self):
         data = self.http(Routes.GATEWAY_GET).json()
@@ -195,7 +229,10 @@ class APIClient(LoggingClass):
         }, headers=_reason_header(reason))
 
     def channels_permissions_delete(self, channel, permission, reason=None):
-        self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission), headers=_reason_header(reason))
+        self.http(
+            Routes.CHANNELS_PERMISSIONS_DELETE,
+            dict(channel=channel, permission=permission), headers=_reason_header(reason)
+        )
 
     def channels_invites_list(self, channel):
         r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel))
@@ -247,7 +284,14 @@ class APIClient(LoggingClass):
         r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
         return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild)
 
-    def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[], reason=None):
+    def guilds_channels_create(self,
+            guild,
+            name,
+            channel_type,
+            bitrate=None,
+            user_limit=None,
+            permission_overwrites=[],
+            reason=None):
         payload = {
             'name': name,
             'channel_type': channel_type,
diff --git a/disco/api/http.py b/disco/api/http.py
index e429a28..2950fcc 100644
--- a/disco/api/http.py
+++ b/disco/api/http.py
@@ -132,6 +132,13 @@ class Routes(object):
     WEBHOOKS_TOKEN_EXECUTE = (HTTPMethod.POST, WEBHOOKS + '/{token}')
 
 
+class APIResponse(object):
+    def __init__(self):
+        self.response = None
+        self.exception = None
+        self.rate_limited_duration = 0
+
+
 class APIException(Exception):
     """
     Exception thrown when an HTTP-client level error occurs. Usually this will
@@ -183,7 +190,7 @@ class HTTPClient(LoggingClass):
     BASE_URL = 'https://discordapp.com/api/v7'
     MAX_RETRIES = 5
 
-    def __init__(self, token):
+    def __init__(self, token, after_request=None):
         super(HTTPClient, self).__init__()
 
         py_version = '{}.{}.{}'.format(
@@ -202,6 +209,7 @@ class HTTPClient(LoggingClass):
         if token:
             self.headers['Authorization'] = 'Bot ' + token
 
+        self.after_request = after_request
         self.session = requests.Session()
 
     def __call__(self, route, args=None, **kwargs):
@@ -251,8 +259,10 @@ class HTTPClient(LoggingClass):
         filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)}
         bucket = (route[0].value, route[1].format(**filtered))
 
+        response = APIResponse()
+
         # Possibly wait if we're rate limited
-        self.limiter.check(bucket)
+        response.rate_limited_duration = self.limiter.check(bucket)
 
         self.log.debug('KW: %s', kwargs)
 
@@ -261,6 +271,10 @@ class HTTPClient(LoggingClass):
         self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params'))
         r = self.session.request(route[0].value, url, **kwargs)
 
+        if self.after_request:
+            response.response = r
+            self.after_request(response)
+
         # Update rate limiter
         self.limiter.update(bucket, r)
 
@@ -269,7 +283,8 @@ class HTTPClient(LoggingClass):
             return r
         elif r.status_code != 429 and 400 <= r.status_code < 500:
             self.log.warning('Request failed with code %s: %s', r.status_code, r.content)
-            raise APIException(r)
+            response.exception = APIException(r)
+            raise response.exception
         else:
             if r.status_code == 429:
                 self.log.warning(
diff --git a/disco/api/ratelimit.py b/disco/api/ratelimit.py
index 054c8cf..ab50488 100644
--- a/disco/api/ratelimit.py
+++ b/disco/api/ratelimit.py
@@ -1,6 +1,7 @@
 import time
 import gevent
 
+
 from disco.util.logging import LoggingClass
 
 
@@ -76,18 +77,18 @@ class RouteState(LoggingClass):
         """
         Waits until this route is no longer under a cooldown.
 
-        Parameters
-        ----------
-        timeout : Optional[int]
-            A timeout (in seconds) after which we will give up waiting
-
-
         Returns
         -------
-        bool
-            False if the timeout period expired before the cooldown was finished.
+        float
+            The duration we waited for, in seconds or zero if we didn't have to
+            wait at all.
         """
-        return self.event.wait(timeout)
+        if self.event.is_set():
+            return 0
+
+        start = time.time()
+        self.event.wait()
+        return time.time() - start
 
     def cooldown(self):
         """
@@ -102,6 +103,7 @@ class RouteState(LoggingClass):
         gevent.sleep(delay)
         self.event.set()
         self.event = None
+        return delay
 
 
 class RateLimiter(LoggingClass):
@@ -117,40 +119,37 @@ class RateLimiter(LoggingClass):
     def __init__(self):
         self.states = {}
 
-    def check(self, route, timeout=None):
+    def check(self, route):
         """
         Checks whether a given route can be called. This function will return
         immediately if no rate-limit cooldown is being imposed for the given
-        route, or will wait indefinitely (unless timeout is specified) until
-        the route is finished being cooled down. This function should be called
-        before making a request to the specified route.
+        route, or will wait indefinitely until the route is finished being
+        cooled down. This function should be called before making a request to
+        the specified route.
 
         Parameters
         ----------
         route : tuple(HTTPMethod, str)
             The route that will be checked.
-        timeout : Optional[int]
-            A timeout after which we'll give up waiting for a route's cooldown
-            to expire, and immediately return.
 
         Returns
         -------
-        bool
-            False if the timeout period expired before the route finished cooling
-            down.
+        float
+            The number of seconds we had to wait for this rate limit, or zero
+            if no time was waited.
         """
-        return self._check(None, timeout) and self._check(route, timeout)
+        return self._check(None) + self._check(route)
 
-    def _check(self, route, timeout=None):
+    def _check(self, route):
         if route in self.states:
-            # If we're current waiting, join the club
+            # If the route is being cooled off, we need to wait until its ready
             if self.states[route].chilled:
-                return self.states[route].wait(timeout)
+                return self.states[route].wait()
 
             if self.states[route].next_will_ratelimit:
-                gevent.spawn(self.states[route].cooldown).get(True, timeout)
+                return gevent.spawn(self.states[route].cooldown).get()
 
-        return True
+        return 0
 
     def update(self, route, response):
         """
diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py
index e7ad8ec..b615aa8 100644
--- a/examples/basic_plugin.py
+++ b/examples/basic_plugin.py
@@ -1,3 +1,5 @@
+from __future__ import print_function
+
 from disco.bot import Plugin
 from disco.util.sanitize import S
 
@@ -10,6 +12,19 @@ class BasicPlugin(Plugin):
         # channel = event.guild.create_channel('audit-log-test', 'text', reason='TEST CREATE')
         # channel.delete(reason='TEST AUDIT 2')
 
+    @Plugin.command('ratelimitme')
+    def on_ratelimitme(self, event):
+        msg = event.msg.reply('Hi!')
+
+        with self.client.api.capture() as requests:
+            for i in range(6):
+                msg.edit('Hi {}!'.format(i))
+
+        print('Rate limited {} for {}'.format(
+            requests.rate_limited,
+            requests.rate_limited_duration(),
+        ))
+
     @Plugin.command('ban', '<user:snowflake> <reason:str...>')
     def on_ban(self, event, user, reason):
         event.guild.create_ban(user, reason=reason + u'\U0001F4BF')
diff --git a/tests/api/client.py b/tests/api/client.py
new file mode 100644
index 0000000..cbc5c83
--- /dev/null
+++ b/tests/api/client.py
@@ -0,0 +1,18 @@
+from disco.api.client import Responses
+from disco.api.http import APIResponse
+
+
+def test_responses_list():
+    r = Responses()
+    r.append(APIResponse())
+    r.append(APIResponse())
+
+    assert not r.rate_limited
+    assert r.rate_limited_duration() == 0
+
+    res = APIResponse()
+    res.rate_limited_duration = 5.5
+    r.append(res)
+
+    assert r.rate_limited
+    assert r.rate_limited_duration() == 5.5