diff --git a/discord/state.py b/discord/state.py index 22b0e74e0..b3da4eabf 100644 --- a/discord/state.py +++ b/discord/state.py @@ -110,12 +110,14 @@ class ChunkRequest: def __init__( self, guild_id: int, + shard_id: int, loop: asyncio.AbstractEventLoop, resolver: Callable[[int], Any], *, cache: bool = True, ) -> None: self.guild_id: int = guild_id + self.shard_id: int = shard_id self.resolver: Callable[[int], Any] = resolver self.loop: asyncio.AbstractEventLoop = loop self.cache: bool = cache @@ -315,6 +317,16 @@ class ConnectionState(Generic[ClientT]): for key in removed: del self._chunk_requests[key] + def clear_chunk_requests(self, shard_id: int | None) -> None: + removed = [] + for key, request in self._chunk_requests.items(): + if shard_id is None or request.shard_id == shard_id: + request.done() + removed.append(key) + + for key in removed: + del self._chunk_requests[key] + def call_handlers(self, key: str, *args: Any, **kwargs: Any) -> None: try: func = self.handlers[key] @@ -535,7 +547,7 @@ class ConnectionState(Generic[ClientT]): if ws is None: raise RuntimeError('Somehow do not have a websocket for this guild_id') - request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + request = ChunkRequest(guild.id, guild.shard_id, self.loop, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request try: @@ -602,6 +614,7 @@ class ConnectionState(Generic[ClientT]): self._ready_state: asyncio.Queue[Guild] = asyncio.Queue() self.clear(views=False) + self.clear_chunk_requests(None) self.user = user = ClientUser(state=self, data=data['user']) self._users[user.id] = user # type: ignore @@ -1204,7 +1217,9 @@ class ConnectionState(Generic[ClientT]): cache = cache or self.member_cache_flags.joined request = self._chunk_requests.get(guild.id) if request is None: - self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + self._chunk_requests[guild.id] = request = ChunkRequest( + guild.id, guild.shard_id, self.loop, self._get_guild, cache=cache + ) await self.chunker(guild.id, nonce=request.nonce) if wait: @@ -1751,6 +1766,7 @@ class AutoShardedConnectionState(ConnectionState[ClientT]): if shard_id in self._ready_tasks: self._ready_tasks[shard_id].cancel() + self.clear_chunk_requests(shard_id) if shard_id not in self._ready_states: self._ready_states[shard_id] = asyncio.Queue()