Browse Source

Allow concurrent calls to guild.chunk()

This allows people who write guild.chunk() calls in highly concurrent
places such as on_message or checks to not spam the gateway with an
actual request and instead waits for the pre-existing request to finish
v1.5.x
Rapptz 5 years ago
parent
commit
1a6295dffb
  1. 58
      discord/state.py

58
discord/state.py

@ -58,13 +58,14 @@ from .object import Object
from .invite import Invite from .invite import Invite
class ChunkRequest: class ChunkRequest:
def __init__(self, guild_id, future, resolver, *, cache=True): def __init__(self, guild_id, loop, resolver, *, cache=True):
self.guild_id = guild_id self.guild_id = guild_id
self.resolver = resolver self.resolver = resolver
self.loop = loop
self.cache = cache self.cache = cache
self.nonce = os.urandom(16).hex() self.nonce = os.urandom(16).hex()
self.future = future
self.buffer = [] # List[Member] self.buffer = [] # List[Member]
self.waiters = []
def add_members(self, members): def add_members(self, members):
self.buffer.extend(members) self.buffer.extend(members)
@ -78,8 +79,24 @@ class ChunkRequest:
if existing is None or existing.joined_at is None: if existing is None or existing.joined_at is None:
guild._add_member(member) guild._add_member(member)
async def wait(self):
future = self.loop.create_future()
self.waiters.append(future)
try:
await future
return True
finally:
self.waiters.remove(future)
def get_future(self):
future = self.loop.create_future()
self.waiters.append(future)
return future
def done(self): def done(self):
self.future.set_result(self.buffer) for future in self.waiters:
if not future.done():
future.set_result(self.buffer)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -116,7 +133,7 @@ class ConnectionState:
raise TypeError('allowed_mentions parameter must be AllowedMentions') raise TypeError('allowed_mentions parameter must be AllowedMentions')
self.allowed_mentions = allowed_mentions self.allowed_mentions = allowed_mentions
self._chunk_requests = [] self._chunk_requests = {} # Dict[Union[int, str], ChunkRequest]
activity = options.get('activity', None) activity = options.get('activity', None)
if activity: if activity:
@ -198,20 +215,15 @@ class ConnectionState:
def process_chunk_requests(self, guild_id, nonce, members, complete): def process_chunk_requests(self, guild_id, nonce, members, complete):
removed = [] removed = []
for i, request in enumerate(self._chunk_requests): for key, request in self._chunk_requests.items():
future = request.future
if future.cancelled():
removed.append(i)
continue
if request.guild_id == guild_id and request.nonce == nonce: if request.guild_id == guild_id and request.nonce == nonce:
request.add_members(members) request.add_members(members)
if complete: if complete:
request.done() request.done()
removed.append(i) removed.append(key)
for index in reversed(removed): for key in removed:
del self._chunk_requests[index] del self._chunk_requests[key]
def call_handlers(self, key, *args, **kwargs): def call_handlers(self, key, *args, **kwargs):
try: try:
@ -377,14 +389,13 @@ class ConnectionState:
if ws is None: if ws is None:
raise RuntimeError('Somehow do not have a websocket for this guild_id') raise RuntimeError('Somehow do not have a websocket for this guild_id')
future = self.loop.create_future() request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
request = ChunkRequest(guild.id, future, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request
self._chunk_requests.append(request)
try: try:
# start the query operation # start the query operation
await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=request.nonce) await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=request.nonce)
return await asyncio.wait_for(future, timeout=30.0) return await asyncio.wait_for(request.wait(), timeout=30.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id)
raise raise
@ -808,13 +819,14 @@ class ConnectionState:
async def chunk_guild(self, guild, *, wait=True, cache=None): async def chunk_guild(self, guild, *, wait=True, cache=None):
cache = cache or self._member_cache_flags.joined cache = cache or self._member_cache_flags.joined
future = self.loop.create_future() request = self._chunk_requests.get(guild.id)
request = ChunkRequest(guild.id, future, self._get_guild, cache=cache) if request is None:
self._chunk_requests.append(request) self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
await self.chunker(guild.id, nonce=request.nonce) await self.chunker(guild.id, nonce=request.nonce)
if wait: if wait:
return await request.future return await request.wait()
return request.future return request.get_future()
async def _chunk_and_dispatch(self, guild, unavailable): async def _chunk_and_dispatch(self, guild, unavailable):
try: try:

Loading…
Cancel
Save