From c02a3c0bb244ccbf55d1c96fb59609f33eee0f27 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 13 Mar 2022 05:02:12 -0400 Subject: [PATCH] Add asynchronous context manager support for Client --- discord/client.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/discord/client.py b/discord/client.py index daac8a1ad..78242691e 100644 --- a/discord/client.py +++ b/discord/client.py @@ -41,6 +41,7 @@ from typing import ( Sequence, TYPE_CHECKING, Tuple, + Type, TypeVar, Union, ) @@ -76,6 +77,8 @@ from .threads import Thread from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory if TYPE_CHECKING: + from typing_extensions import Self + from types import TracebackType from .types.guild import Guild as GuildPayload from .abc import SnowflakeTime, Snowflake, PrivateChannel from .guild import GuildChannel @@ -180,10 +183,7 @@ class Client: The websocket gateway the client is currently connected to. Could be ``None``. """ - def __init__( - self, - **options: Any, - ): + def __init__(self, **options: Any) -> None: self.loop: asyncio.AbstractEventLoop = MISSING # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore @@ -216,6 +216,19 @@ class Client: VoiceClient.warn_nacl = False _log.warning("PyNaCl is not installed, voice will NOT be supported") + async def __aenter__(self) -> Self: + self.loop = asyncio.get_running_loop() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + if not self.is_closed(): + await self.close() + # internals def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: @@ -601,12 +614,8 @@ class Client: self.loop = asyncio.get_running_loop() self.http.loop = self.loop self._connection.loop = self.loop - try: - await self.login(token) - await self.connect(reconnect=reconnect) - finally: - if not self.is_closed(): - await self.close() + await self.login(token) + await self.connect(reconnect=reconnect) async def setup_hook(self) -> None: """|coro| @@ -645,8 +654,13 @@ class Client: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ + + async def runner(): + async with self: + await self.start(*args, **kwargs) + try: - asyncio.run(self.start(*args, **kwargs)) + asyncio.run(runner()) except KeyboardInterrupt: # nothing to do here # `asyncio.run` handles the loop cleanup