diff --git a/discord/http.py b/discord/http.py index e795ccd1e..365970336 100644 --- a/discord/http.py +++ b/discord/http.py @@ -47,7 +47,6 @@ from typing import ( ) from urllib.parse import quote as _uriquote from collections import deque -import weakref import datetime import aiohttp @@ -329,6 +328,19 @@ class Ratelimit: everything into a single lock queue per route. """ + __slots__ = ( + 'limit', + 'remaining', + 'outgoing', + 'reset_after', + 'expires', + 'dirty', + '_max_ratelimit_timeout', + '_loop', + '_pending_requests', + '_sleeping', + ) + def __init__(self, max_ratelimit_timeout: Optional[float]) -> None: self.limit: int = 1 self.remaining: int = self.limit @@ -392,7 +404,6 @@ class Ratelimit: future.set_exception(exception) else: future.set_result(None) - self._has_just_awaken = True awaken += 1 if awaken >= count: @@ -481,12 +492,12 @@ class HTTPClient: # Route key -> Bucket hash self._bucket_hashes: Dict[str, str] = {} # Bucket Hash + Major Parameters -> Rate limit - self._buckets: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary() + # or # Route key + Major Parameters -> Rate limit - # Used for temporary one shot requests that don't have a bucket hash - # While I'd love to use a single mapping for these, doing this would cause the rate limit objects - # to inexplicably be evicted from the dictionary before we're done with it - self._oneshots: weakref.WeakValueDictionary[str, Ratelimit] = weakref.WeakValueDictionary() + # When the key is the latter, it is used for temporary + # one shot requests that don't have a bucket hash + # When this reaches 256 elements, it will try to evict based off of expiry + self._buckets: Dict[str, Ratelimit] = {} self._global_over: asyncio.Event = MISSING self.token: Optional[str] = None self.proxy: Optional[str] = proxy @@ -517,6 +528,22 @@ class HTTPClient: return await self.__session.ws_connect(url, **kwargs) + def _try_clear_expired_ratelimits(self) -> None: + if len(self._buckets) < 256: + return + + keys = [key for key, bucket in self._buckets.items() if bucket.is_expired()] + for key in keys: + del self._buckets[key] + + def get_ratelimit(self, key: str) -> Ratelimit: + try: + value = self._buckets[key] + except KeyError: + self._buckets[key] = value = Ratelimit(self.max_ratelimit_timeout) + self._try_clear_expired_ratelimits() + return value + async def request( self, route: Route, @@ -533,16 +560,11 @@ class HTTPClient: try: bucket_hash = self._bucket_hashes[route_key] except KeyError: - key = route_key + route.major_parameters - mapping = self._oneshots + key = f'{route_key}:{route.major_parameters}' else: - key = bucket_hash + route.major_parameters - mapping = self._buckets + key = f'{bucket_hash}:{route.major_parameters}' - try: - ratelimit = mapping[key] - except KeyError: - mapping[key] = ratelimit = Ratelimit(self.max_ratelimit_timeout) + ratelimit = self.get_ratelimit(key) # header creation headers: Dict[str, str] = {