diff --git a/discord/client.py b/discord/client.py index cdfc6bdbc..4279a36ba 100644 --- a/discord/client.py +++ b/discord/client.py @@ -26,10 +26,24 @@ from __future__ import annotations import asyncio import logging -import signal import sys import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Coroutine, + Dict, + Generator, + Iterable, + List, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + TypeVar, + Type, + Union, +) import aiohttp @@ -68,6 +82,7 @@ if TYPE_CHECKING: from .message import Message from .member import Member from .voice_client import VoiceProtocol + from types import TracebackType __all__ = ( 'Client', @@ -78,36 +93,8 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) log: logging.Logger = logging.getLogger(__name__) -def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: - tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} - - if not tasks: - return - - log.info('Cleaning up after %d tasks.', len(tasks)) - for task in tasks: - task.cancel() - - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - log.info('All tasks finished cancelling.') - - for task in tasks: - if task.cancelled(): - continue - if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'Unhandled exception during Client.run shutdown.', - 'exception': task.exception(), - 'task': task - }) - -def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: - try: - _cancel_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - log.info('Closing the event loop.') - loop.close() +C = TypeVar('C', bound='Client') + class Client: r"""Represents a client connection that connects to Discord. @@ -200,6 +187,7 @@ class Client: loop: :class:`asyncio.AbstractEventLoop` The event loop that the client uses for asynchronous operations. """ + def __init__( self, *, @@ -207,7 +195,8 @@ class Client: **options: Any, ): self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop + # this is filled in later + self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: Optional[int] = options.get('shard_id') self.shard_count: Optional[int] = options.get('shard_count') @@ -216,14 +205,16 @@ class Client: proxy: Optional[str] = options.pop('proxy', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) unsync_clock: bool = options.pop('assume_unsync_clock', True) - self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) + self.http: HTTPClient = HTTPClient( + connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop + ) self._handlers: Dict[str, Callable] = { - 'ready': self._handle_ready + 'ready': self._handle_ready, } self._hooks: Dict[str, Callable] = { - 'before_identify': self._call_before_identify_hook + 'before_identify': self._call_before_identify_hook, } self._enable_debug_events: bool = options.pop('enable_debug_events', False) @@ -244,8 +235,9 @@ class Client: return self.ws def _get_state(self, **options: Any) -> ConnectionState: - return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, - hooks=self._hooks, http=self.http, loop=self.loop, **options) + return ConnectionState( + dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options + ) def _handle_ready(self) -> None: self._ready.set() @@ -343,7 +335,9 @@ class Client: """:class:`bool`: Specifies if the client's internal cache is ready for use.""" return self._ready.is_set() - async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: + async def _run_event( + self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any + ) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -354,7 +348,9 @@ class Client: except asyncio.CancelledError: pass - def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: + def _schedule_event( + self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any + ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') @@ -466,7 +462,8 @@ class Client: """ log.info('logging in using static token') - + self.loop = loop = asyncio.get_running_loop() + self._connection.loop = loop data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) @@ -512,12 +509,14 @@ class Client: self.dispatch('disconnect') ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) continue - except (OSError, - HTTPException, - GatewayNotFound, - ConnectionClosed, - aiohttp.ClientError, - asyncio.TimeoutError) as exc: + except ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) as exc: self.dispatch('disconnect') if not reconnect: @@ -558,6 +557,22 @@ class Client: """|coro| Closes the connection to Discord. + + Instead of calling this directly, it is recommended to use the asynchronous context + manager to allow resources to be cleaned up automatically: + + .. code-block:: python3 + + async def main(): + async with Client() as client: + await client.login(token) + await client.connect() + + asyncio.run(main()) + + + .. versionchanged:: 2.0 + The client can now be closed with an asynchronous context manager """ if self._closed: return @@ -589,36 +604,47 @@ class Client: self._connection.clear() self.http.recreate() + async def __aenter__(self: C) -> C: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + await self.close() + async def start(self, token: str, *, reconnect: bool = True) -> None: """|coro| - A shorthand coroutine for :meth:`login` + :meth:`connect`. + A shorthand function equivalent to the following: - Raises - ------- - TypeError - An unexpected keyword argument was received. + .. code-block:: python3 + + async with client: + await client.login(token) + await client.connect() + + This closes the client when it returns. """ - await self.login(token) - await self.connect(reconnect=reconnect) + try: + await self.login(token) + await self.connect(reconnect=reconnect) + finally: + await self.close() def run(self, *args: Any, **kwargs: Any) -> None: - """A blocking call that abstracts away the event loop + """A convenience blocking call that abstracts away the event loop initialisation from you. If you want more control over the event loop then this function should not be used. Use :meth:`start` coroutine or :meth:`connect` + :meth:`login`. - Roughly Equivalent to: :: + Equivalent to: :: - try: - loop.run_until_complete(start(*args, **kwargs)) - except KeyboardInterrupt: - loop.run_until_complete(close()) - # cancel all tasks lingering - finally: - loop.close() + asyncio.run(bot.start(*args, **kwargs)) .. warning:: @@ -626,41 +652,7 @@ class Client: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - loop = self.loop - - try: - loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) - loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) - except NotImplementedError: - pass - - async def runner(): - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() - - def stop_loop_on_completion(f): - loop.stop() - - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) - try: - loop.run_forever() - except KeyboardInterrupt: - log.info('Received signal to terminate bot and event loop.') - finally: - future.remove_done_callback(stop_loop_on_completion) - log.info('Cleaning up tasks.') - _cleanup_loop(loop) - - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + asyncio.run(self.start(*args, **kwargs)) # properties @@ -973,8 +965,10 @@ class Client: future = self.loop.create_future() if check is None: + def _check(*args): return True + check = _check ev = event.lower() @@ -1083,7 +1077,7 @@ class Client: *, limit: Optional[int] = 100, before: SnowflakeTime = None, - after: SnowflakeTime = None + after: SnowflakeTime = None, ) -> GuildIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. @@ -1163,7 +1157,7 @@ class Client: """ code = utils.resolve_template(code) data = await self.http.get_template(code) - return Template(data=data, state=self._connection) # type: ignore + return Template(data=data, state=self._connection) # type: ignore async def fetch_guild(self, guild_id: int) -> Guild: """|coro| @@ -1284,7 +1278,9 @@ class Client: # Invite management - async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: + async def fetch_invite( + self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True + ) -> Invite: """|coro| Gets an :class:`.Invite` from a discord.gg URL or ID. @@ -1520,7 +1516,7 @@ class Client: """ data = await self.http.get_sticker(sticker_id) cls, _ = _sticker_factory(data['type']) # type: ignore - return cls(state=self._connection, data=data) # type: ignore + return cls(state=self._connection, data=data) # type: ignore async def fetch_premium_sticker_packs(self) -> List[StickerPack]: """|coro| diff --git a/discord/http.py b/discord/http.py index b186782ff..5a31928b5 100644 --- a/discord/http.py +++ b/discord/http.py @@ -167,7 +167,7 @@ class HTTPClient: loop: Optional[asyncio.AbstractEventLoop] = None, unsync_clock: bool = True ) -> None: - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop + self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() @@ -371,6 +371,7 @@ class HTTPClient: async def static_login(self, token: str) -> user.User: # Necessary to get aiohttp to stop complaining about session creation + self.loop = asyncio.get_running_loop() self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) old_token = self.token self.token = token