From 13a3f760e6de8aa251f06bbe2a746e5f92deafd2 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 10 May 2020 19:30:46 -0400 Subject: [PATCH] Fix timeout issues with fetching members via query_members This uses the nonce field to properly disambiguate queries. There's also some redesigning going on behind the scenes and minor clean-up. Originally I planned on working on this more to account for the more widespread chunking changes planned for gateway v7 but I realized that this would indiscriminately slow down everyone else who isn't planning on working with intents for now. I will work on the larger chunking changes in the future, should time allow for it. --- discord/client.py | 24 +++++------------------- discord/gateway.py | 8 ++++++-- discord/shard.py | 31 ++++++------------------------- discord/state.py | 37 +++++++++++++++++++++++++------------ 4 files changed, 42 insertions(+), 58 deletions(-) diff --git a/discord/client.py b/discord/client.py index f292c9d21..03845fd15 100644 --- a/discord/client.py +++ b/discord/client.py @@ -223,13 +223,13 @@ class Client: 'ready': self._handle_ready } - self._connection = ConnectionState(dispatch=self.dispatch, chunker=self._chunker, handlers=self._handlers, + self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers, syncer=self._syncer, http=self.http, loop=self.loop, **options) self._connection.shard_count = self.shard_count self._closed = False self._ready = asyncio.Event() - self._connection._get_websocket = lambda g: self.ws + self._connection._get_websocket = self._get_websocket if VoiceClient.warn_nacl: VoiceClient.warn_nacl = False @@ -237,26 +237,12 @@ class Client: # internals + def _get_websocket(self, guild_id=None, *, shard_id=None): + return self.ws + async def _syncer(self, guilds): await self.ws.request_sync(guilds) - async def _chunker(self, guild): - try: - guild_id = guild.id - except AttributeError: - guild_id = [s.id for s in guild] - - payload = { - 'op': 8, - 'd': { - 'guild_id': guild_id, - 'query': '', - 'limit': 0 - } - } - - await self.ws.send_as_json(payload) - def _handle_ready(self): self._ready.set() diff --git a/discord/gateway.py b/discord/gateway.py index 9b0f1d819..15368d565 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -535,15 +535,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): } await self.send_as_json(payload) - async def request_chunks(self, guild_id, query, limit): + async def request_chunks(self, guild_id, query, limit, *, nonce=None): payload = { 'op': self.REQUEST_MEMBERS, 'd': { - 'guild_id': str(guild_id), + 'guild_id': guild_id, 'query': query, 'limit': limit } } + + if nonce: + payload['d']['nonce'] = nonce + await self.send_as_json(payload) async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): diff --git a/discord/shard.py b/discord/shard.py index e133cd0ce..6d599dab6 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -126,38 +126,19 @@ class AutoShardedClient(Client): elif not isinstance(self.shard_ids, (list, tuple)): raise ClientException('shard_ids parameter must be a list or a tuple.') - self._connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self._chunker, + self._connection = AutoShardedConnectionState(dispatch=self.dispatch, handlers=self._handlers, syncer=self._syncer, http=self.http, loop=self.loop, **kwargs) # instead of a single websocket, we have multiple # the key is the shard_id self.shards = {} + self._connection._get_websocket = self._get_websocket - def _get_websocket(guild_id): - i = (guild_id >> 22) % self.shard_count - return self.shards[i].ws - - self._connection._get_websocket = _get_websocket - - async def _chunker(self, guild, *, shard_id=None): - try: - guild_id = guild.id - shard_id = shard_id or guild.shard_id - except AttributeError: - guild_id = [s.id for s in guild] - - payload = { - 'op': 8, - 'd': { - 'guild_id': guild_id, - 'query': '', - 'limit': 0 - } - } - - ws = self.shards[shard_id].ws - await ws.send_as_json(payload) + def _get_websocket(self, guild_id=None, *, shard_id=None): + if shard_id is None: + shard_id = (guild_id >> 22) % self.shard_count + return self.shards[shard_id].ws @property def latency(self): diff --git a/discord/state.py b/discord/state.py index e630724e7..0e8764b6c 100644 --- a/discord/state.py +++ b/discord/state.py @@ -35,6 +35,9 @@ import weakref import inspect import gc +import os +import binascii + from .guild import Guild from .activity import BaseActivity from .user import User, ClientUser @@ -62,7 +65,7 @@ log = logging.getLogger(__name__) ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) class ConnectionState: - def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options): + def __init__(self, *, dispatch, handlers, syncer, http, loop, **options): self.loop = loop self.http = http self.max_messages = options.get('max_messages', 1000) @@ -70,7 +73,6 @@ class ConnectionState: self.max_messages = 1000 self.dispatch = dispatch - self.chunker = chunker self.syncer = syncer self.is_bot = None self.handlers = handlers @@ -132,6 +134,9 @@ class ConnectionState: # to reconnect loops which cause mass allocations and deallocations. gc.collect() + def get_nonce(self): + return binascii.hexlify(os.urandom(16)).decode('ascii') + def process_listeners(self, listener_type, argument, result): removed = [] for i, listener in enumerate(self._listeners): @@ -298,6 +303,10 @@ class ConnectionState: return channel or Object(id=channel_id), guild + async def chunker(self, guild_id, query='', limit=0, *, nonce=None): + ws = self._get_websocket(guild_id) # This is ignored upstream + await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) + async def request_offline_members(self, guilds): # get all the chunks chunks = [] @@ -307,7 +316,7 @@ class ConnectionState: # we only want to request ~75 guilds per chunk request. splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] for split in splits: - await self.chunker(split) + await self.chunker([g.id for g in split]) # wait for the chunks if chunks: @@ -329,10 +338,11 @@ class ConnectionState: # and they don't receive GUILD_MEMBER events which make computing # member_count impossible. The only way to fix it is by limiting # the limit parameter to 1 to 1000. - future = self.receive_member_query(guild_id, query) + nonce = self.get_nonce() + future = self.receive_member_query(guild_id, nonce) try: # start the query operation - await ws.request_chunks(guild_id, query, limit) + await ws.request_chunks(guild_id, query, limit, nonce=nonce) members = await asyncio.wait_for(future, timeout=5.0) if cache: @@ -894,8 +904,7 @@ class ConnectionState: guild._add_member(member) self.process_listeners(ListenerType.chunk, guild, len(members)) - names = [x.name.lower() for x in members] - self.process_listeners(ListenerType.query_members, (guild_id, names), members) + self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members) def parse_guild_integrations_update(self, data): guild = self._get_guild(int(data['guild_id'])) @@ -1025,10 +1034,10 @@ class ConnectionState: self._listeners.append(listener) return future - def receive_member_query(self, guild_id, query): - def predicate(args, *, guild_id=guild_id, query=query.lower()): - request_guild_id, names = args - return request_guild_id == guild_id and all(n.startswith(query) for n in names) + def receive_member_query(self, guild_id, nonce): + def predicate(args, *, guild_id=guild_id, nonce=nonce): + return args == (guild_id, nonce) + future = self.loop.create_future() listener = Listener(ListenerType.query_members, future, predicate) self._listeners.append(listener) @@ -1040,6 +1049,10 @@ class AutoShardedConnectionState(ConnectionState): self._ready_task = None self.shard_ids = () + async def chunker(self, guild_id, query='', limit=0, *, shard_id, nonce=None): + ws = self._get_websocket(shard_id=shard_id) + await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) + async def request_offline_members(self, guilds, *, shard_id): # get all the chunks chunks = [] @@ -1049,7 +1062,7 @@ class AutoShardedConnectionState(ConnectionState): # we only want to request ~75 guilds per chunk request. splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] for split in splits: - await self.chunker(split, shard_id=shard_id) + await self.chunker([g.id for g in split], shard_id=shard_id) # wait for the chunks if chunks: