diff --git a/disco/util/limiter.py b/disco/util/limiter.py index ccb7622..493d4ae 100644 --- a/disco/util/limiter.py +++ b/disco/util/limiter.py @@ -6,33 +6,17 @@ class SimpleLimiter(object): def __init__(self, total, per): self.total = total self.per = per + self._lock = gevent.lock.Semaphore(total) self.count = 0 self.reset_at = 0 - - self.event = None - - def backoff(self): - self.event = gevent.event.Event() - gevent.sleep(self.reset_at - time.time()) - self.count = 0 - self.reset_at = 0 - if self.event: - self.event.set() self.event = None def check(self): - if self.event: - self.event.wait() - - self.count += 1 + self._lock.acquire() - if not self.reset_at: - self.reset_at = time.time() + self.per - return - elif self.reset_at < time.time(): - self.count = 1 - self.reset_at = time.time() + def _release_lock(): + gevent.sleep(self.per) + self._lock.release() - if self.count > self.total and self.reset_at > time.time(): - self.backoff() + gevent.spawn(_release_lock) diff --git a/tests/test_util_limiter.py b/tests/test_util_limiter.py new file mode 100644 index 0000000..0c1fabc --- /dev/null +++ b/tests/test_util_limiter.py @@ -0,0 +1,46 @@ +import time +import gevent +from unittest import TestCase + +from disco.util.limiter import SimpleLimiter + + +class TestSimpleLimiter(TestCase): + def test_many_wait_ratelimiter(self): + limit = SimpleLimiter(5, 1) + many = [] + + def check(lock): + limit.check() + lock.release() + + start = time.time() + for _ in range(16): + lock = gevent.lock.Semaphore() + lock.acquire() + many.append(lock) + gevent.spawn(check, lock) + + for item in many: + item.acquire() + + self.assertGreater(time.time() - start, 3) + + def test_nowait_ratelimiter(self): + limit = SimpleLimiter(5, 1) + + start = time.time() + for _ in range(5): + limit.check() + + self.assertLess(time.time() - start, 1) + + def test_single_wait_ratelimiter(self): + limit = SimpleLimiter(5, 1) + + start = time.time() + for _ in range(10): + limit.check() + + + self.assertEqual(int(time.time() - start), 1)