Browse Source

Add asynchronous context manager support for Client

pull/7674/head
Rapptz 3 years ago
parent
commit
c02a3c0bb2
  1. 36
      discord/client.py

36
discord/client.py

@ -41,6 +41,7 @@ from typing import (
Sequence, Sequence,
TYPE_CHECKING, TYPE_CHECKING,
Tuple, Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
) )
@ -76,6 +77,8 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .abc import SnowflakeTime, Snowflake, PrivateChannel from .abc import SnowflakeTime, Snowflake, PrivateChannel
from .guild import GuildChannel from .guild import GuildChannel
@ -180,10 +183,7 @@ class Client:
The websocket gateway the client is currently connected to. Could be ``None``. The websocket gateway the client is currently connected to. Could be ``None``.
""" """
def __init__( def __init__(self, **options: Any) -> None:
self,
**options: Any,
):
self.loop: asyncio.AbstractEventLoop = MISSING self.loop: asyncio.AbstractEventLoop = MISSING
# self.ws is set in the connect method # self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore self.ws: DiscordWebSocket = None # type: ignore
@ -216,6 +216,19 @@ class Client:
VoiceClient.warn_nacl = False VoiceClient.warn_nacl = False
_log.warning("PyNaCl is not installed, voice will NOT be supported") _log.warning("PyNaCl is not installed, voice will NOT be supported")
async def __aenter__(self) -> Self:
self.loop = asyncio.get_running_loop()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
await self.close()
# internals # internals
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
@ -601,12 +614,8 @@ class Client:
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.http.loop = self.loop self.http.loop = self.loop
self._connection.loop = self.loop self._connection.loop = self.loop
try: await self.login(token)
await self.login(token) await self.connect(reconnect=reconnect)
await self.connect(reconnect=reconnect)
finally:
if not self.is_closed():
await self.close()
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""|coro| """|coro|
@ -645,8 +654,13 @@ class Client:
is blocking. That means that registration of events or anything being is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns. called after this function call will not execute until it returns.
""" """
async def runner():
async with self:
await self.start(*args, **kwargs)
try: try:
asyncio.run(self.start(*args, **kwargs)) asyncio.run(runner())
except KeyboardInterrupt: except KeyboardInterrupt:
# nothing to do here # nothing to do here
# `asyncio.run` handles the loop cleanup # `asyncio.run` handles the loop cleanup

Loading…
Cancel
Save