From 1a6295dffb40632be5670eacc8f1e880108423d8 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 19 Oct 2020 02:46:04 -0400 Subject: [PATCH] Allow concurrent calls to guild.chunk() This allows people who write guild.chunk() calls in highly concurrent places such as on_message or checks to not spam the gateway with an actual request and instead waits for the pre-existing request to finish --- discord/state.py | 58 +++++++++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/discord/state.py b/discord/state.py index 791ffc1f2..5937fec32 100644 --- a/discord/state.py +++ b/discord/state.py @@ -58,13 +58,14 @@ from .object import Object from .invite import Invite class ChunkRequest: - def __init__(self, guild_id, future, resolver, *, cache=True): + def __init__(self, guild_id, loop, resolver, *, cache=True): self.guild_id = guild_id self.resolver = resolver + self.loop = loop self.cache = cache self.nonce = os.urandom(16).hex() - self.future = future self.buffer = [] # List[Member] + self.waiters = [] def add_members(self, members): self.buffer.extend(members) @@ -78,8 +79,24 @@ class ChunkRequest: if existing is None or existing.joined_at is None: guild._add_member(member) + async def wait(self): + future = self.loop.create_future() + self.waiters.append(future) + try: + await future + return True + finally: + self.waiters.remove(future) + + def get_future(self): + future = self.loop.create_future() + self.waiters.append(future) + return future + def done(self): - self.future.set_result(self.buffer) + for future in self.waiters: + if not future.done(): + future.set_result(self.buffer) log = logging.getLogger(__name__) @@ -116,7 +133,7 @@ class ConnectionState: raise TypeError('allowed_mentions parameter must be AllowedMentions') self.allowed_mentions = allowed_mentions - self._chunk_requests = [] + self._chunk_requests = {} # Dict[Union[int, str], ChunkRequest] activity = options.get('activity', None) if activity: @@ -198,20 +215,15 @@ class ConnectionState: def process_chunk_requests(self, guild_id, nonce, members, complete): removed = [] - for i, request in enumerate(self._chunk_requests): - future = request.future - if future.cancelled(): - removed.append(i) - continue - + for key, request in self._chunk_requests.items(): if request.guild_id == guild_id and request.nonce == nonce: request.add_members(members) if complete: request.done() - removed.append(i) + removed.append(key) - for index in reversed(removed): - del self._chunk_requests[index] + for key in removed: + del self._chunk_requests[key] def call_handlers(self, key, *args, **kwargs): try: @@ -377,14 +389,13 @@ class ConnectionState: if ws is None: raise RuntimeError('Somehow do not have a websocket for this guild_id') - future = self.loop.create_future() - request = ChunkRequest(guild.id, future, self._get_guild, cache=cache) - self._chunk_requests.append(request) + request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + self._chunk_requests[request.nonce] = request try: # start the query operation await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=request.nonce) - return await asyncio.wait_for(future, timeout=30.0) + return await asyncio.wait_for(request.wait(), timeout=30.0) except asyncio.TimeoutError: log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) raise @@ -808,13 +819,14 @@ class ConnectionState: async def chunk_guild(self, guild, *, wait=True, cache=None): cache = cache or self._member_cache_flags.joined - future = self.loop.create_future() - request = ChunkRequest(guild.id, future, self._get_guild, cache=cache) - self._chunk_requests.append(request) - await self.chunker(guild.id, nonce=request.nonce) + request = self._chunk_requests.get(guild.id) + if request is None: + self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + await self.chunker(guild.id, nonce=request.nonce) + if wait: - return await request.future - return request.future + return await request.wait() + return request.get_future() async def _chunk_and_dispatch(self, guild, unavailable): try: