Browse Source

Add asynchronous context manager support for Client

pull/7674/head
Rapptz 3 years ago
parent
commit
c02a3c0bb2
  1. 36
      discord/client.py

36
discord/client.py

@ -41,6 +41,7 @@ from typing import (
Sequence,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
)
@ -76,6 +77,8 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, Snowflake, PrivateChannel
from .guild import GuildChannel
@ -180,10 +183,7 @@ class Client:
The websocket gateway the client is currently connected to. Could be ``None``.
"""
def __init__(
self,
**options: Any,
):
def __init__(self, **options: Any) -> None:
self.loop: asyncio.AbstractEventLoop = MISSING
# self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore
@ -216,6 +216,19 @@ class Client:
VoiceClient.warn_nacl = False
_log.warning("PyNaCl is not installed, voice will NOT be supported")
async def __aenter__(self) -> Self:
self.loop = asyncio.get_running_loop()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
await self.close()
# internals
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
@ -601,12 +614,8 @@ class Client:
self.loop = asyncio.get_running_loop()
self.http.loop = self.loop
self._connection.loop = self.loop
try:
await self.login(token)
await self.connect(reconnect=reconnect)
finally:
if not self.is_closed():
await self.close()
await self.login(token)
await self.connect(reconnect=reconnect)
async def setup_hook(self) -> None:
"""|coro|
@ -645,8 +654,13 @@ class Client:
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
async def runner():
async with self:
await self.start(*args, **kwargs)
try:
asyncio.run(self.start(*args, **kwargs))
asyncio.run(runner())
except KeyboardInterrupt:
# nothing to do here
# `asyncio.run` handles the loop cleanup

Loading…
Cancel
Save