From 08a4db396118aeda6205ff56c8c8fc565fc338fc Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 11 Aug 2021 02:16:22 -0400 Subject: [PATCH] Revert "Refactor Client.run to use asyncio.run" This reverts commit 6e6c8a7b2810747222a938c7fe3e466c2994b23f. --- discord/client.py | 200 +++++++++++++++++++++++----------------------- discord/http.py | 3 +- 2 files changed, 103 insertions(+), 100 deletions(-) diff --git a/discord/client.py b/discord/client.py index 4279a36ba..cdfc6bdbc 100644 --- a/discord/client.py +++ b/discord/client.py @@ -26,24 +26,10 @@ from __future__ import annotations import asyncio import logging +import signal import sys import traceback -from typing import ( - Any, - Callable, - Coroutine, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - TYPE_CHECKING, - Tuple, - TypeVar, - Type, - Union, -) +from typing import Any, Callable, Coroutine, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union import aiohttp @@ -82,7 +68,6 @@ if TYPE_CHECKING: from .message import Message from .member import Member from .voice_client import VoiceProtocol - from types import TracebackType __all__ = ( 'Client', @@ -93,8 +78,36 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) log: logging.Logger = logging.getLogger(__name__) -C = TypeVar('C', bound='Client') - +def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: + tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} + + if not tasks: + return + + log.info('Cleaning up after %d tasks.', len(tasks)) + for task in tasks: + task.cancel() + + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + log.info('All tasks finished cancelling.') + + for task in tasks: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'Unhandled exception during Client.run shutdown.', + 'exception': task.exception(), + 'task': task + }) + +def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: + try: + _cancel_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + log.info('Closing the event loop.') + loop.close() class Client: r"""Represents a client connection that connects to Discord. @@ -187,7 +200,6 @@ class Client: loop: :class:`asyncio.AbstractEventLoop` The event loop that the client uses for asynchronous operations. """ - def __init__( self, *, @@ -195,8 +207,7 @@ class Client: **options: Any, ): self.ws: DiscordWebSocket = None # type: ignore - # this is filled in later - self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: Optional[int] = options.get('shard_id') self.shard_count: Optional[int] = options.get('shard_count') @@ -205,16 +216,14 @@ class Client: proxy: Optional[str] = options.pop('proxy', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) unsync_clock: bool = options.pop('assume_unsync_clock', True) - self.http: HTTPClient = HTTPClient( - connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=loop - ) + self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) self._handlers: Dict[str, Callable] = { - 'ready': self._handle_ready, + 'ready': self._handle_ready } self._hooks: Dict[str, Callable] = { - 'before_identify': self._call_before_identify_hook, + 'before_identify': self._call_before_identify_hook } self._enable_debug_events: bool = options.pop('enable_debug_events', False) @@ -235,9 +244,8 @@ class Client: return self.ws def _get_state(self, **options: Any) -> ConnectionState: - return ConnectionState( - dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options - ) + return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, + hooks=self._hooks, http=self.http, loop=self.loop, **options) def _handle_ready(self) -> None: self._ready.set() @@ -335,9 +343,7 @@ class Client: """:class:`bool`: Specifies if the client's internal cache is ready for use.""" return self._ready.is_set() - async def _run_event( - self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any - ) -> None: + async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -348,9 +354,7 @@ class Client: except asyncio.CancelledError: pass - def _schedule_event( - self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any - ) -> asyncio.Task: + def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') @@ -462,8 +466,7 @@ class Client: """ log.info('logging in using static token') - self.loop = loop = asyncio.get_running_loop() - self._connection.loop = loop + data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) @@ -509,14 +512,12 @@ class Client: self.dispatch('disconnect') ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) continue - except ( - OSError, - HTTPException, - GatewayNotFound, - ConnectionClosed, - aiohttp.ClientError, - asyncio.TimeoutError, - ) as exc: + except (OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError) as exc: self.dispatch('disconnect') if not reconnect: @@ -557,22 +558,6 @@ class Client: """|coro| Closes the connection to Discord. - - Instead of calling this directly, it is recommended to use the asynchronous context - manager to allow resources to be cleaned up automatically: - - .. code-block:: python3 - - async def main(): - async with Client() as client: - await client.login(token) - await client.connect() - - asyncio.run(main()) - - - .. versionchanged:: 2.0 - The client can now be closed with an asynchronous context manager """ if self._closed: return @@ -604,47 +589,36 @@ class Client: self._connection.clear() self.http.recreate() - async def __aenter__(self: C) -> C: - return self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - await self.close() - async def start(self, token: str, *, reconnect: bool = True) -> None: """|coro| - A shorthand function equivalent to the following: + A shorthand coroutine for :meth:`login` + :meth:`connect`. - .. code-block:: python3 - - async with client: - await client.login(token) - await client.connect() - - This closes the client when it returns. + Raises + ------- + TypeError + An unexpected keyword argument was received. """ - try: - await self.login(token) - await self.connect(reconnect=reconnect) - finally: - await self.close() + await self.login(token) + await self.connect(reconnect=reconnect) def run(self, *args: Any, **kwargs: Any) -> None: - """A convenience blocking call that abstracts away the event loop + """A blocking call that abstracts away the event loop initialisation from you. If you want more control over the event loop then this function should not be used. Use :meth:`start` coroutine or :meth:`connect` + :meth:`login`. - Equivalent to: :: + Roughly Equivalent to: :: - asyncio.run(bot.start(*args, **kwargs)) + try: + loop.run_until_complete(start(*args, **kwargs)) + except KeyboardInterrupt: + loop.run_until_complete(close()) + # cancel all tasks lingering + finally: + loop.close() .. warning:: @@ -652,7 +626,41 @@ class Client: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - asyncio.run(self.start(*args, **kwargs)) + loop = self.loop + + try: + loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) + loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) + except NotImplementedError: + pass + + async def runner(): + try: + await self.start(*args, **kwargs) + finally: + if not self.is_closed(): + await self.close() + + def stop_loop_on_completion(f): + loop.stop() + + future = asyncio.ensure_future(runner(), loop=loop) + future.add_done_callback(stop_loop_on_completion) + try: + loop.run_forever() + except KeyboardInterrupt: + log.info('Received signal to terminate bot and event loop.') + finally: + future.remove_done_callback(stop_loop_on_completion) + log.info('Cleaning up tasks.') + _cleanup_loop(loop) + + if not future.cancelled(): + try: + return future.result() + except KeyboardInterrupt: + # I am unsure why this gets raised here but suppress it anyway + return None # properties @@ -965,10 +973,8 @@ class Client: future = self.loop.create_future() if check is None: - def _check(*args): return True - check = _check ev = event.lower() @@ -1077,7 +1083,7 @@ class Client: *, limit: Optional[int] = 100, before: SnowflakeTime = None, - after: SnowflakeTime = None, + after: SnowflakeTime = None ) -> GuildIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. @@ -1157,7 +1163,7 @@ class Client: """ code = utils.resolve_template(code) data = await self.http.get_template(code) - return Template(data=data, state=self._connection) # type: ignore + return Template(data=data, state=self._connection) # type: ignore async def fetch_guild(self, guild_id: int) -> Guild: """|coro| @@ -1278,9 +1284,7 @@ class Client: # Invite management - async def fetch_invite( - self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True - ) -> Invite: + async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: """|coro| Gets an :class:`.Invite` from a discord.gg URL or ID. @@ -1516,7 +1520,7 @@ class Client: """ data = await self.http.get_sticker(sticker_id) cls, _ = _sticker_factory(data['type']) # type: ignore - return cls(state=self._connection, data=data) # type: ignore + return cls(state=self._connection, data=data) # type: ignore async def fetch_premium_sticker_packs(self) -> List[StickerPack]: """|coro| diff --git a/discord/http.py b/discord/http.py index 5a31928b5..b186782ff 100644 --- a/discord/http.py +++ b/discord/http.py @@ -167,7 +167,7 @@ class HTTPClient: loop: Optional[asyncio.AbstractEventLoop] = None, unsync_clock: bool = True ) -> None: - self.loop: asyncio.AbstractEventLoop = MISSING if loop is None else loop # filled in static_login + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() @@ -371,7 +371,6 @@ class HTTPClient: async def static_login(self, token: str) -> user.User: # Necessary to get aiohttp to stop complaining about session creation - self.loop = asyncio.get_running_loop() self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) old_token = self.token self.token = token