Browse Source

Implement rate limiting

pull/3/head
Andrei 9 years ago
parent
commit
668390543e
  1. 2
      README.md
  2. 58
      disco/api/http.py
  3. 64
      disco/api/ratelimit.py
  4. 1
      disco/bot/bot.py
  5. 2
      disco/bot/command.py
  6. 8
      examples/basic_plugin.py

2
README.md

@ -3,9 +3,9 @@ A Discord Python bot built to be easy to use and scale.
## TODOS
- rate limits
- flesh out gateway paths (reconnect/resume)
- flesh out API client
- storage/database/config
- flesh out type methods
- plugin reload
- voice support

58
disco/api/http.py

@ -2,6 +2,9 @@ import requests
from holster.enum import Enum
from disco.util.logging import LoggingClass
from disco.api.ratelimit import RateLimiter
HTTPMethod = Enum(
GET='GET',
POST='POST',
@ -21,36 +24,57 @@ class Routes(object):
class APIException(Exception):
def __init__(self, obj):
self.code = obj['code']
self.msg = obj['msg']
super(APIException, self).__init__(self.msg)
def __init__(self, msg, status_code=0, content=None):
super(APIException, self).__init__(msg)
self.status_code = status_code
self.content = content
class HTTPClient(object):
class HTTPClient(LoggingClass):
BASE_URL = 'https://discordapp.com/api'
MAX_RETRIES = 5
def __init__(self, token):
super(HTTPClient, self).__init__()
self.limiter = RateLimiter()
self.headers = {
'Authorization': 'Bot ' + token,
}
def __call__(self, route, *args, **kwargs):
method, url = route
retry = kwargs.pop('retry_number', 0)
# Merge or set headers
if 'headers' in kwargs:
kwargs['headers'].update(self.headers)
else:
kwargs['headers'] = self.headers
r = requests.request(str(method), (self.BASE_URL + url).format(*args), **kwargs)
# Compile URL args
compiled = (str(route[0]), (self.BASE_URL) + route[1].format(*args))
# Possibly wait if we're rate limited
self.limiter.check(compiled)
# Make the actual request
r = requests.request(compiled[0], compiled[1], **kwargs)
try:
r.raise_for_status()
except:
print r.json()
raise
# TODO: rate limits
# TODO: check json
raise APIException(r.json())
# Update rate limiter
self.limiter.update(compiled, r)
# TODO: check json
# If we got a success status code, just return the data
if r.status_code < 400:
return r.json()
else:
if r.status_code == 429:
self.log.warning('Request responded w/ 429, retrying (but this should not happen, check your clock sync')
# If we hit the max retries, throw an error
retry += 1
if retry > self.MAX_RETRIES:
self.log.error('Failing request, hit max retries')
raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content)
# Otherwise just recurse and try again
return self(route, retry_number=retry, *args, **kwargs)

64
disco/api/ratelimit.py

@ -0,0 +1,64 @@
import time
import gevent
class RouteState(object):
def __init__(self, route, request):
self.route = route
self.remaining = 0
self.reset_time = 0
self.event = None
self.update(request)
@property
def chilled(self):
return self.event is not None
def update(self, request):
if 'X-RateLimit-Remaining' not in request.headers:
return
self.remaining = int(request.headers.get('X-RateLimit-Remaining'))
self.reset_time = int(request.headers.get('X-RateLimit-Reset'))
def wait(self, timeout=None):
self.event.wait(timeout)
def next_will_ratelimit(self):
if self.remaining - 1 < 0 and time.time() <= self.reset_time:
return True
return False
def cooldown(self):
if self.reset_time - time.time() < 0:
raise Exception('Cannot cooldown for negative time period; check clock sync')
self.event = gevent.event.Event()
gevent.sleep((self.reset_time - time.time()) + .5)
self.event.set()
self.event = None
class RateLimiter(object):
def __init__(self):
self.cooldowns = {}
self.states = {}
def check(self, route, timeout=None):
if route in self.states:
# If we're current waiting, join the club
if self.states[route].chilled:
return self.states[route].wait(timeout)
if self.states[route].next_will_ratelimit():
self.states[route].cooldown()
return True
def update(self, route, request):
if route in self.states:
self.states[route].update(request)
else:
self.states[route] = RouteState(route, request)

1
disco/bot/bot.py

@ -44,7 +44,6 @@ class Bot(object):
def compute_command_matches_re(self):
re_str = '|'.join(command.regex for command in self.commands)
print re_str
if re_str:
self.command_matches_re = re.compile(re_str)
else:

2
disco/bot/command.py

@ -9,7 +9,7 @@ class CommandEvent(object):
def __init__(self, msg, match):
self.msg = msg
self.match = match
self.args = self.match.group(1).split(' ')
self.args = self.match.group(1).strip().split(' ')
class Command(object):

8
examples/basic_plugin.py

@ -11,7 +11,13 @@ class BasicPlugin(Plugin):
@Plugin.command('test')
def on_test_command(self, event):
event.msg.reply('HELLO WORLD')
print 'wtf'
@Plugin.command('spam')
def on_spam_command(self, event):
count = int(event.args[0])
for i in range(count):
event.msg.reply(' '.join(event.args[1:]))
if __name__ == '__main__':
bot = Bot(disco_main())

Loading…
Cancel
Save