Browse Source

Fix timeout issues with fetching members via query_members

This uses the nonce field to properly disambiguate queries. There's
also some redesigning going on behind the scenes and minor clean-up.
Originally I planned on working on this more to account for the more
widespread chunking changes planned for gateway v7 but I realized that
this would indiscriminately slow down everyone else who isn't planning
on working with intents for now.

I will work on the larger chunking changes in the future, should time
allow for it.
pull/4088/head
Rapptz 5 years ago
parent
commit
13a3f760e6
  1. 24
      discord/client.py
  2. 8
      discord/gateway.py
  3. 31
      discord/shard.py
  4. 37
      discord/state.py

24
discord/client.py

@ -223,13 +223,13 @@ class Client:
'ready': self._handle_ready 'ready': self._handle_ready
} }
self._connection = ConnectionState(dispatch=self.dispatch, chunker=self._chunker, handlers=self._handlers, self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
syncer=self._syncer, http=self.http, loop=self.loop, **options) syncer=self._syncer, http=self.http, loop=self.loop, **options)
self._connection.shard_count = self.shard_count self._connection.shard_count = self.shard_count
self._closed = False self._closed = False
self._ready = asyncio.Event() self._ready = asyncio.Event()
self._connection._get_websocket = lambda g: self.ws self._connection._get_websocket = self._get_websocket
if VoiceClient.warn_nacl: if VoiceClient.warn_nacl:
VoiceClient.warn_nacl = False VoiceClient.warn_nacl = False
@ -237,26 +237,12 @@ class Client:
# internals # internals
def _get_websocket(self, guild_id=None, *, shard_id=None):
return self.ws
async def _syncer(self, guilds): async def _syncer(self, guilds):
await self.ws.request_sync(guilds) await self.ws.request_sync(guilds)
async def _chunker(self, guild):
try:
guild_id = guild.id
except AttributeError:
guild_id = [s.id for s in guild]
payload = {
'op': 8,
'd': {
'guild_id': guild_id,
'query': '',
'limit': 0
}
}
await self.ws.send_as_json(payload)
def _handle_ready(self): def _handle_ready(self):
self._ready.set() self._ready.set()

8
discord/gateway.py

@ -535,15 +535,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
} }
await self.send_as_json(payload) await self.send_as_json(payload)
async def request_chunks(self, guild_id, query, limit): async def request_chunks(self, guild_id, query, limit, *, nonce=None):
payload = { payload = {
'op': self.REQUEST_MEMBERS, 'op': self.REQUEST_MEMBERS,
'd': { 'd': {
'guild_id': str(guild_id), 'guild_id': guild_id,
'query': query, 'query': query,
'limit': limit 'limit': limit
} }
} }
if nonce:
payload['d']['nonce'] = nonce
await self.send_as_json(payload) await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):

31
discord/shard.py

@ -126,38 +126,19 @@ class AutoShardedClient(Client):
elif not isinstance(self.shard_ids, (list, tuple)): elif not isinstance(self.shard_ids, (list, tuple)):
raise ClientException('shard_ids parameter must be a list or a tuple.') raise ClientException('shard_ids parameter must be a list or a tuple.')
self._connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self._chunker, self._connection = AutoShardedConnectionState(dispatch=self.dispatch,
handlers=self._handlers, syncer=self._syncer, handlers=self._handlers, syncer=self._syncer,
http=self.http, loop=self.loop, **kwargs) http=self.http, loop=self.loop, **kwargs)
# instead of a single websocket, we have multiple # instead of a single websocket, we have multiple
# the key is the shard_id # the key is the shard_id
self.shards = {} self.shards = {}
self._connection._get_websocket = self._get_websocket
def _get_websocket(guild_id): def _get_websocket(self, guild_id=None, *, shard_id=None):
i = (guild_id >> 22) % self.shard_count if shard_id is None:
return self.shards[i].ws shard_id = (guild_id >> 22) % self.shard_count
return self.shards[shard_id].ws
self._connection._get_websocket = _get_websocket
async def _chunker(self, guild, *, shard_id=None):
try:
guild_id = guild.id
shard_id = shard_id or guild.shard_id
except AttributeError:
guild_id = [s.id for s in guild]
payload = {
'op': 8,
'd': {
'guild_id': guild_id,
'query': '',
'limit': 0
}
}
ws = self.shards[shard_id].ws
await ws.send_as_json(payload)
@property @property
def latency(self): def latency(self):

37
discord/state.py

@ -35,6 +35,9 @@ import weakref
import inspect import inspect
import gc import gc
import os
import binascii
from .guild import Guild from .guild import Guild
from .activity import BaseActivity from .activity import BaseActivity
from .user import User, ClientUser from .user import User, ClientUser
@ -62,7 +65,7 @@ log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
class ConnectionState: class ConnectionState:
def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options): def __init__(self, *, dispatch, handlers, syncer, http, loop, **options):
self.loop = loop self.loop = loop
self.http = http self.http = http
self.max_messages = options.get('max_messages', 1000) self.max_messages = options.get('max_messages', 1000)
@ -70,7 +73,6 @@ class ConnectionState:
self.max_messages = 1000 self.max_messages = 1000
self.dispatch = dispatch self.dispatch = dispatch
self.chunker = chunker
self.syncer = syncer self.syncer = syncer
self.is_bot = None self.is_bot = None
self.handlers = handlers self.handlers = handlers
@ -132,6 +134,9 @@ class ConnectionState:
# to reconnect loops which cause mass allocations and deallocations. # to reconnect loops which cause mass allocations and deallocations.
gc.collect() gc.collect()
def get_nonce(self):
return binascii.hexlify(os.urandom(16)).decode('ascii')
def process_listeners(self, listener_type, argument, result): def process_listeners(self, listener_type, argument, result):
removed = [] removed = []
for i, listener in enumerate(self._listeners): for i, listener in enumerate(self._listeners):
@ -298,6 +303,10 @@ class ConnectionState:
return channel or Object(id=channel_id), guild return channel or Object(id=channel_id), guild
async def chunker(self, guild_id, query='', limit=0, *, nonce=None):
ws = self._get_websocket(guild_id) # This is ignored upstream
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
async def request_offline_members(self, guilds): async def request_offline_members(self, guilds):
# get all the chunks # get all the chunks
chunks = [] chunks = []
@ -307,7 +316,7 @@ class ConnectionState:
# we only want to request ~75 guilds per chunk request. # we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits: for split in splits:
await self.chunker(split) await self.chunker([g.id for g in split])
# wait for the chunks # wait for the chunks
if chunks: if chunks:
@ -329,10 +338,11 @@ class ConnectionState:
# and they don't receive GUILD_MEMBER events which make computing # and they don't receive GUILD_MEMBER events which make computing
# member_count impossible. The only way to fix it is by limiting # member_count impossible. The only way to fix it is by limiting
# the limit parameter to 1 to 1000. # the limit parameter to 1 to 1000.
future = self.receive_member_query(guild_id, query) nonce = self.get_nonce()
future = self.receive_member_query(guild_id, nonce)
try: try:
# start the query operation # start the query operation
await ws.request_chunks(guild_id, query, limit) await ws.request_chunks(guild_id, query, limit, nonce=nonce)
members = await asyncio.wait_for(future, timeout=5.0) members = await asyncio.wait_for(future, timeout=5.0)
if cache: if cache:
@ -894,8 +904,7 @@ class ConnectionState:
guild._add_member(member) guild._add_member(member)
self.process_listeners(ListenerType.chunk, guild, len(members)) self.process_listeners(ListenerType.chunk, guild, len(members))
names = [x.name.lower() for x in members] self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members)
self.process_listeners(ListenerType.query_members, (guild_id, names), members)
def parse_guild_integrations_update(self, data): def parse_guild_integrations_update(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
@ -1025,10 +1034,10 @@ class ConnectionState:
self._listeners.append(listener) self._listeners.append(listener)
return future return future
def receive_member_query(self, guild_id, query): def receive_member_query(self, guild_id, nonce):
def predicate(args, *, guild_id=guild_id, query=query.lower()): def predicate(args, *, guild_id=guild_id, nonce=nonce):
request_guild_id, names = args return args == (guild_id, nonce)
return request_guild_id == guild_id and all(n.startswith(query) for n in names)
future = self.loop.create_future() future = self.loop.create_future()
listener = Listener(ListenerType.query_members, future, predicate) listener = Listener(ListenerType.query_members, future, predicate)
self._listeners.append(listener) self._listeners.append(listener)
@ -1040,6 +1049,10 @@ class AutoShardedConnectionState(ConnectionState):
self._ready_task = None self._ready_task = None
self.shard_ids = () self.shard_ids = ()
async def chunker(self, guild_id, query='', limit=0, *, shard_id, nonce=None):
ws = self._get_websocket(shard_id=shard_id)
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
async def request_offline_members(self, guilds, *, shard_id): async def request_offline_members(self, guilds, *, shard_id):
# get all the chunks # get all the chunks
chunks = [] chunks = []
@ -1049,7 +1062,7 @@ class AutoShardedConnectionState(ConnectionState):
# we only want to request ~75 guilds per chunk request. # we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits: for split in splits:
await self.chunker(split, shard_id=shard_id) await self.chunker([g.id for g in split], shard_id=shard_id)
# wait for the chunks # wait for the chunks
if chunks: if chunks:

Loading…
Cancel
Save