From 2dbf14bb726faaf5c6b2d4dd4033834c710dbce9 Mon Sep 17 00:00:00 2001 From: Lilly Rose Berner Date: Mon, 25 Apr 2022 08:01:46 +0200 Subject: [PATCH] Separately delay ready event for each shard --- discord/client.py | 1 - discord/gateway.py | 2 - discord/shard.py | 2 - discord/state.py | 155 ++++++++++++++++++++------------------- discord/types/gateway.py | 7 +- 5 files changed, 82 insertions(+), 85 deletions(-) diff --git a/discord/client.py b/discord/client.py index aee47af78..09b0af256 100644 --- a/discord/client.py +++ b/discord/client.py @@ -484,7 +484,6 @@ class Client: self.loop = loop self.http.loop = loop self._connection.loop = loop - await self._connection.async_setup() self._ready = asyncio.Event() diff --git a/discord/gateway.py b/discord/gateway.py index 9dba51752..e757c150f 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -546,8 +546,6 @@ class DiscordWebSocket: self._trace = trace = data.get('_trace', []) self.sequence = msg['s'] self.session_id = data['session_id'] - # pass back shard ID to ready handler - data['__shard_id__'] = self.shard_id _log.info( 'Shard ID %s has connected to Gateway: %s (Session ID: %s).', self.shard_id, diff --git a/discord/shard.py b/discord/shard.py index c44931c79..bd16c18c1 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -423,8 +423,6 @@ class AutoShardedClient(Client): initial = shard_id == shard_ids[0] await self.launch_shard(gateway, shard_id, initial=initial) - self._connection.shards_launched.set() - async def _async_setup_hook(self) -> None: await super()._async_setup_hook() self.__queue = asyncio.PriorityQueue() diff --git a/discord/state.py b/discord/state.py index 9e16a2a6e..3069075e8 100644 --- a/discord/state.py +++ b/discord/state.py @@ -27,7 +27,6 @@ from __future__ import annotations import asyncio from collections import deque, OrderedDict import copy -import itertools import logging from typing import ( Dict, @@ -302,9 +301,6 @@ class ConnectionState: else: await coro(*args, **kwargs) - async def async_setup(self) -> None: - pass - @property def self_id(self) -> Optional[int]: u = self.user @@ -561,7 +557,7 @@ class ConnectionState: if self._ready_task is not None: self._ready_task.cancel() - self._ready_state = asyncio.Queue() + self._ready_state: asyncio.Queue[Guild] = asyncio.Queue() self.clear(views=False) self.user = user = ClientUser(state=self, data=data['user']) self._users[user.id] = user # type: ignore @@ -1111,6 +1107,15 @@ class ConnectionState: else: self.dispatch('guild_join', guild) + def _add_ready_state(self, guild: Guild) -> bool: + try: + # Notify the on_ready state, if any, that this guild is complete. + self._ready_state.put_nowait(guild) + except AttributeError: + return False + else: + return True + def parse_guild_create(self, data: gw.GuildCreateEvent) -> None: unavailable = data.get('unavailable') if unavailable is True: @@ -1119,14 +1124,8 @@ class ConnectionState: guild = self._get_create_guild(data) - try: - # Notify the on_ready state, if any, that this guild is complete. - self._ready_state.put_nowait(guild) - except AttributeError: - pass - else: - # If we're waiting for the event, put the rest on hold - return + if self._add_ready_state(guild): + return # We're waiting for the ready event, put the rest on hold # check if it requires chunking if self._guild_needs_chunking(guild): @@ -1510,8 +1509,12 @@ class ConnectionState: class AutoShardedConnectionState(ConnectionState): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) + self.shard_ids: Union[List[int], range] = [] + self._ready_tasks: Dict[int, asyncio.Task[None]] = {} + self._ready_states: Dict[int, asyncio.Queue[Guild]] = {} + def _update_message_references(self) -> None: # self._messages won't be None when this is called for msg in self._messages: # type: ignore @@ -1525,9 +1528,6 @@ class AutoShardedConnectionState(ConnectionState): # channel will either be a TextChannel, Thread or Object msg._rebind_cached_references(new_guild, channel) # type: ignore - async def async_setup(self) -> None: - self.shards_launched: asyncio.Event = asyncio.Event() - async def chunker( self, guild_id: int, @@ -1541,76 +1541,80 @@ class AutoShardedConnectionState(ConnectionState): ws = self._get_websocket(guild_id, shard_id=shard_id) await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + def _add_ready_state(self, guild: Guild) -> bool: + try: + # Notify the on_ready state, if any, that this guild is complete. + self._ready_states[guild.shard_id].put_nowait(guild) + except KeyError: + return False + else: + return True + async def _delay_ready(self) -> None: - await self.shards_launched.wait() - processed = [] - max_concurrency = len(self.shard_ids) * 2 - current_bucket = [] - while True: - # this snippet of code is basically waiting N seconds - # until the last GUILD_CREATE was sent - try: - guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) - except asyncio.TimeoutError: - break - else: - if self._guild_needs_chunking(guild): - _log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id) - if len(current_bucket) >= max_concurrency: - try: - await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) - except asyncio.TimeoutError: - fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d' - _log.warning(fmt, guild.shard_id, len(current_bucket)) - finally: - current_bucket = [] - - # Chunk the guild in the background while we wait for GUILD_CREATE streaming - future = asyncio.ensure_future(self.chunk_guild(guild)) - current_bucket.append(future) + await asyncio.gather(*self._ready_tasks.values()) + + # clear the current tasks + self._ready_task = None + self._ready_tasks = {} + + # dispatch the event + self.call_handlers('ready') + self.dispatch('ready') + + async def _delay_shard_ready(self, shard_id: int) -> None: + try: + states = [] + while True: + # this snippet of code is basically waiting N seconds + # until the last GUILD_CREATE was sent + try: + guild = await asyncio.wait_for(self._ready_states[shard_id].get(), timeout=self.guild_ready_timeout) + except asyncio.TimeoutError: + break else: - future = self.loop.create_future() - future.set_result([]) + if self._guild_needs_chunking(guild): + future = await self.chunk_guild(guild, wait=False) + states.append((guild, future)) + else: + if guild.unavailable is False: + self.dispatch('guild_available', guild) + else: + self.dispatch('guild_join', guild) - processed.append((guild, future)) + for guild, future in states: + try: + await asyncio.wait_for(future, timeout=5.0) + except asyncio.TimeoutError: + _log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id) - guilds = sorted(processed, key=lambda g: g[0].shard_id) - for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id): - children, futures = zip(*info) - # 110 reqs/minute w/ 1 req/guild plus some buffer - timeout = 61 * (len(children) / 110) - try: - await utils.sane_wait_for(futures, timeout=timeout) - except asyncio.TimeoutError: - _log.warning( - 'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds) - ) - for guild in children: if guild.unavailable is False: self.dispatch('guild_available', guild) else: self.dispatch('guild_join', guild) - self.dispatch('shard_ready', shard_id) + # remove the state + try: + del self._ready_states[shard_id] + except KeyError: + pass # already been deleted somehow - # remove the state - try: - del self._ready_state - except AttributeError: - pass # already been deleted somehow + except asyncio.CancelledError: + pass + else: + # dispatch the event + self.dispatch('shard_ready', shard_id) - # regular users cannot shard so we won't worry about it here. + def parse_ready(self, data: gw.ReadyEvent) -> None: + if self._ready_task is not None: + self._ready_task.cancel() - # clear the current task - self._ready_task = None + shard_id = data['shard'][0] # shard_id, num_shards - # dispatch the event - self.call_handlers('ready') - self.dispatch('ready') + if shard_id in self._ready_tasks: + self._ready_tasks[shard_id].cancel() - def parse_ready(self, data: gw.ReadyEvent) -> None: - if not hasattr(self, '_ready_state'): - self._ready_state = asyncio.Queue() + if shard_id not in self._ready_states: + self._ready_states[shard_id] = asyncio.Queue() self.user: Optional[ClientUser] self.user = user = ClientUser(state=self, data=data['user']) @@ -1633,9 +1637,12 @@ class AutoShardedConnectionState(ConnectionState): self._update_message_references() self.dispatch('connect') - self.dispatch('shard_connect', data['__shard_id__']) # type: ignore # This is an internal discord.py key + self.dispatch('shard_connect', shard_id) + + self._ready_tasks[shard_id] = asyncio.create_task(self._delay_shard_ready(shard_id)) - if self._ready_task is None: + # The delay task for every shard has been started + if len(self._ready_tasks) == len(self.shard_ids): self._ready_task = asyncio.create_task(self._delay_ready()) def parse_resumed(self, data: gw.ResumedEvent) -> None: diff --git a/discord/types/gateway.py b/discord/types/gateway.py index 266e40d8c..fb6e21e43 100644 --- a/discord/types/gateway.py +++ b/discord/types/gateway.py @@ -60,17 +60,12 @@ class GatewayBot(Gateway): session_start_limit: SessionStartLimit -class ShardInfo(TypedDict): - shard_id: int - shard_count: int - - class ReadyEvent(TypedDict): v: int user: User guilds: List[UnavailableGuild] session_id: str - shard: ShardInfo + shard: List[int] # shard_id, num_shards application: GatewayAppInfo