From 668390543eb7a537daf2dff2584d81cadb08b39e Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 22 Sep 2016 16:40:21 -0500 Subject: [PATCH] Implement rate limiting --- README.md | 2 +- disco/api/http.py | 62 ++++++++++++++++++++++++++------------ disco/api/ratelimit.py | 64 ++++++++++++++++++++++++++++++++++++++++ disco/bot/bot.py | 1 - disco/bot/command.py | 2 +- examples/basic_plugin.py | 8 ++++- 6 files changed, 116 insertions(+), 23 deletions(-) create mode 100644 disco/api/ratelimit.py diff --git a/README.md b/README.md index 13e3f30..339ee16 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@ A Discord Python bot built to be easy to use and scale. ## TODOS -- rate limits - flesh out gateway paths (reconnect/resume) - flesh out API client +- storage/database/config - flesh out type methods - plugin reload - voice support diff --git a/disco/api/http.py b/disco/api/http.py index 1ab05d2..93a7b37 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -2,6 +2,9 @@ import requests from holster.enum import Enum +from disco.util.logging import LoggingClass +from disco.api.ratelimit import RateLimiter + HTTPMethod = Enum( GET='GET', POST='POST', @@ -21,36 +24,57 @@ class Routes(object): class APIException(Exception): - def __init__(self, obj): - self.code = obj['code'] - self.msg = obj['msg'] - - super(APIException, self).__init__(self.msg) + def __init__(self, msg, status_code=0, content=None): + super(APIException, self).__init__(msg) + self.status_code = status_code + self.content = content -class HTTPClient(object): +class HTTPClient(LoggingClass): BASE_URL = 'https://discordapp.com/api' + MAX_RETRIES = 5 def __init__(self, token): + super(HTTPClient, self).__init__() + + self.limiter = RateLimiter() self.headers = { 'Authorization': 'Bot ' + token, } def __call__(self, route, *args, **kwargs): - method, url = route + retry = kwargs.pop('retry_number', 0) + + # Merge or set headers + if 'headers' in kwargs: + kwargs['headers'].update(self.headers) + else: + kwargs['headers'] = self.headers + + # Compile URL args + compiled = (str(route[0]), (self.BASE_URL) + route[1].format(*args)) + + # Possibly wait if we're rate limited + self.limiter.check(compiled) + + # Make the actual request + r = requests.request(compiled[0], compiled[1], **kwargs) - kwargs['headers'] = self.headers + # Update rate limiter + self.limiter.update(compiled, r) - r = requests.request(str(method), (self.BASE_URL + url).format(*args), **kwargs) + # If we got a success status code, just return the data + if r.status_code < 400: + return r.json() + else: + if r.status_code == 429: + self.log.warning('Request responded w/ 429, retrying (but this should not happen, check your clock sync') - try: - r.raise_for_status() - except: - print r.json() - raise - # TODO: rate limits - # TODO: check json - raise APIException(r.json()) + # If we hit the max retries, throw an error + retry += 1 + if retry > self.MAX_RETRIES: + self.log.error('Failing request, hit max retries') + raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) - # TODO: check json - return r.json() + # Otherwise just recurse and try again + return self(route, retry_number=retry, *args, **kwargs) diff --git a/disco/api/ratelimit.py b/disco/api/ratelimit.py new file mode 100644 index 0000000..d0acc25 --- /dev/null +++ b/disco/api/ratelimit.py @@ -0,0 +1,64 @@ +import time +import gevent + + +class RouteState(object): + def __init__(self, route, request): + self.route = route + self.remaining = 0 + self.reset_time = 0 + self.event = None + + self.update(request) + + @property + def chilled(self): + return self.event is not None + + def update(self, request): + if 'X-RateLimit-Remaining' not in request.headers: + return + + self.remaining = int(request.headers.get('X-RateLimit-Remaining')) + self.reset_time = int(request.headers.get('X-RateLimit-Reset')) + + def wait(self, timeout=None): + self.event.wait(timeout) + + def next_will_ratelimit(self): + if self.remaining - 1 < 0 and time.time() <= self.reset_time: + return True + + return False + + def cooldown(self): + if self.reset_time - time.time() < 0: + raise Exception('Cannot cooldown for negative time period; check clock sync') + + self.event = gevent.event.Event() + gevent.sleep((self.reset_time - time.time()) + .5) + self.event.set() + self.event = None + + +class RateLimiter(object): + def __init__(self): + self.cooldowns = {} + self.states = {} + + def check(self, route, timeout=None): + if route in self.states: + # If we're current waiting, join the club + if self.states[route].chilled: + return self.states[route].wait(timeout) + + if self.states[route].next_will_ratelimit(): + self.states[route].cooldown() + + return True + + def update(self, route, request): + if route in self.states: + self.states[route].update(request) + else: + self.states[route] = RouteState(route, request) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 25d95c2..18a9646 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -44,7 +44,6 @@ class Bot(object): def compute_command_matches_re(self): re_str = '|'.join(command.regex for command in self.commands) - print re_str if re_str: self.command_matches_re = re.compile(re_str) else: diff --git a/disco/bot/command.py b/disco/bot/command.py index c6788eb..e3248de 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -9,7 +9,7 @@ class CommandEvent(object): def __init__(self, msg, match): self.msg = msg self.match = match - self.args = self.match.group(1).split(' ') + self.args = self.match.group(1).strip().split(' ') class Command(object): diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index d88ad7c..9df21f3 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -11,7 +11,13 @@ class BasicPlugin(Plugin): @Plugin.command('test') def on_test_command(self, event): event.msg.reply('HELLO WORLD') - print 'wtf' + + @Plugin.command('spam') + def on_spam_command(self, event): + count = int(event.args[0]) + + for i in range(count): + event.msg.reply(' '.join(event.args[1:])) if __name__ == '__main__': bot = Bot(disco_main())