From 9c091e10445dcd69b3dcf3c939cb9faa470b9a58 Mon Sep 17 00:00:00 2001 From: Michael H Date: Wed, 26 Feb 2025 10:11:00 -0500 Subject: [PATCH] Ensure ready tasks are restarted as needed on resume --- discord/state.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/discord/state.py b/discord/state.py index 0fbeadea2..a502dadee 100644 --- a/discord/state.py +++ b/discord/state.py @@ -677,6 +677,19 @@ class ConnectionState(Generic[ClientT]): self.dispatch('connect') self._ready_task = asyncio.create_task(self._delay_ready()) + self._ready_task.add_done_callback(self._ready_resume_done_callback) + + def _ready_resume_done_callback(self, task: asyncio.Task): + # https://github.com/Rapptz/discord.py/issues/10118 + # We can get a resume during chunking, which results + # in attempting to write to a closing transport. + if task.cancelled(): + return + exc = task.exception() + if exc and isinstance(exc, ConnectionResetError): + _log.debug("Restarting delay ready due to connection reset") + self._ready_task = asyncio.create_task(self._delay_ready()) + self._ready_task.add_done_callback(self._ready_resume_done_callback) def parse_resumed(self, data: gw.ResumedEvent) -> None: self.dispatch('resumed') @@ -1977,6 +1990,29 @@ class AutoShardedConnectionState(ConnectionState[ClientT]): # The delay task for every shard has been started if len(self._ready_tasks) == len(self.shard_ids): self._ready_task = asyncio.create_task(self._delay_ready()) + self._ready_task.add_done_callback(self._ready_resume_done_callback) + + def _ready_resume_done_callback(self, task: asyncio.Task): + # https://github.com/Rapptz/discord.py/issues/10118 + # We can get a resume during chunking, which results + # in attempting to write to a closing transport. + if task.cancelled(): + return + exc = task.exception() + if exc and isinstance(exc, ConnectionResetError): + # This was raised up while gathering, find all tasks that need restarting + needs_restart = [ + shard_id + for shard_id, shard_ready in self._ready_tasks.items() + if (not shard_ready.cancelled()) and isinstance(shard_ready.exception(), ConnectionResetError) + ] + for shard_id in needs_restart: + _log.debug("Shard ID %s Restarting shard ready delay due to connection reset", shard_id) + self._ready_tasks[shard_id] = asyncio.create_task(self._delay_shard_ready(shard_id)) + + _log.debug("Restarting delay ready due to connection reset") + self._ready_task = asyncio.create_task(self._delay_ready()) + self._ready_task.add_done_callback(self._ready_resume_done_callback) def parse_resumed(self, data: gw.ResumedEvent) -> None: self.dispatch('resumed')