Browse Source

Ensure Client.close() has finished in __aexit__

This wraps the closing behavior in a task. Subsequent callers of
.close() now await that same close finishing rather than short
circuiting. This prevents a user-called close outside of __aexit__ from
not finishing before no longer having a running event loop.
pull/9556/merge
Michael H 11 months ago
committed by GitHub
parent
commit
88f62d85d2
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 33
      discord/client.py
  2. 21
      discord/shard.py

33
discord/client.py

@ -287,7 +287,7 @@ class Client:
self._enable_debug_events: bool = options.pop('enable_debug_events', False) self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options) self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
self._connection.shard_count = self.shard_count self._connection.shard_count = self.shard_count
self._closed: bool = False self._closing_task: Optional[asyncio.Task[None]] = None
self._ready: asyncio.Event = MISSING self._ready: asyncio.Event = MISSING
self._application: Optional[AppInfo] = None self._application: Optional[AppInfo] = None
self._connection._get_websocket = self._get_websocket self._connection._get_websocket = self._get_websocket
@ -307,7 +307,10 @@ class Client:
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
) -> None: ) -> None:
if not self.is_closed(): # This avoids double-calling a user-provided .close()
if self._closing_task:
await self._closing_task
else:
await self.close() await self.close()
# internals # internals
@ -726,22 +729,24 @@ class Client:
Closes the connection to Discord. Closes the connection to Discord.
""" """
if self._closed: if self._closing_task:
return return await self._closing_task
self._closed = True async def _close():
await self._connection.close()
await self._connection.close() if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)
if self.ws is not None and self.ws.open: await self.http.close()
await self.ws.close(code=1000)
await self.http.close() if self._ready is not MISSING:
self._ready.clear()
if self._ready is not MISSING: self.loop = MISSING
self._ready.clear()
self.loop = MISSING self._closing_task = asyncio.create_task(_close())
await self._closing_task
def clear(self) -> None: def clear(self) -> None:
"""Clears the internal state of the bot. """Clears the internal state of the bot.
@ -750,7 +755,7 @@ class Client:
and :meth:`is_ready` both return ``False`` along with the bot's internal and :meth:`is_ready` both return ``False`` along with the bot's internal
cache cleared. cache cleared.
""" """
self._closed = False self._closing_task = None
self._ready.clear() self._ready.clear()
self._connection.clear() self._connection.clear()
self.http.clear() self.http.clear()
@ -870,7 +875,7 @@ class Client:
def is_closed(self) -> bool: def is_closed(self) -> bool:
""":class:`bool`: Indicates if the websocket connection is closed.""" """:class:`bool`: Indicates if the websocket connection is closed."""
return self._closed return self._closing_task is not None
@property @property
def activity(self) -> Optional[ActivityTypes]: def activity(self) -> Optional[ActivityTypes]:

21
discord/shard.py

@ -481,18 +481,21 @@ class AutoShardedClient(Client):
Closes the connection to Discord. Closes the connection to Discord.
""" """
if self.is_closed(): if self._closing_task:
return return await self._closing_task
async def _close():
await self._connection.close()
self._closed = True to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
await self._connection.close() if to_close:
await asyncio.wait(to_close)
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()] await self.http.close()
if to_close: self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
await asyncio.wait(to_close)
await self.http.close() self._closing_task = asyncio.create_task(_close())
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None)) await self._closing_task
async def change_presence( async def change_presence(
self, self,

Loading…
Cancel
Save