From 105c55a0b1ac3c699dca07951e9a1e9af09eae54 Mon Sep 17 00:00:00 2001 From: Michael H Date: Sat, 4 May 2024 23:20:36 -0400 Subject: [PATCH] Ensure Client.close() has finished in __aexit__ This wraps the closing behavior in a task. Subsequent callers of .close() now await that same close finishing rather than short circuiting. This prevents a user-called close outside of __aexit__ from not finishing before no longer having a running event loop. --- discord/client.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/discord/client.py b/discord/client.py index 9121617b3..49651d7e0 100644 --- a/discord/client.py +++ b/discord/client.py @@ -317,7 +317,7 @@ class Client: self._enable_debug_events: bool = options.pop('enable_debug_events', False) self._sync_presences: bool = options.pop('sync_presence', True) self._connection: ConnectionState = self._get_state(**options) - self._closed: bool = False + self._closing_task: Optional[asyncio.Task[None]] = None self._ready: asyncio.Event = MISSING if VoiceClient.warn_nacl: @@ -334,7 +334,10 @@ class Client: exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - if not self.is_closed(): + # This avoids double-calling a user-provided .close() + if self._closing_task: + await self._closing_task + else: await self.close() # Internals @@ -969,27 +972,29 @@ class Client: Closes the connection to Discord. """ - if self._closed: - return + if self._closing_task: + return await self._closing_task - self._closed = True + async def _close(): + for voice in self.voice_clients: + try: + await voice.disconnect(force=True) + except Exception: + # If an error happens during disconnects, disregard it + pass - for voice in self.voice_clients: - try: - await voice.disconnect(force=True) - except Exception: - # If an error happens during disconnects, disregard it - pass + if self.ws is not None and self.ws.open: + await self.ws.close(code=1000) - if self.ws is not None and self.ws.open: - await self.ws.close(code=1000) + await self.http.close() - await self.http.close() + if self._ready is not MISSING: + self._ready.clear() - if self._ready is not MISSING: - self._ready.clear() + self.loop = MISSING - self.loop = MISSING + self._closing_task = asyncio.create_task(_close()) + await self._closing_task def clear(self) -> None: """Clears the internal state of the bot. @@ -998,7 +1003,7 @@ class Client: and :meth:`is_ready` both return ``False`` along with the bot's internal cache cleared. """ - self._closed = False + self._closing_task = None self._ready.clear() self._connection.clear(full=True) self.http.clear() @@ -1114,7 +1119,7 @@ class Client: def is_closed(self) -> bool: """:class:`bool`: Indicates if the websocket connection is closed.""" - return self._closed + return self._closing_task is not None @property def voice_client(self) -> Optional[VoiceProtocol]: