From 93af158b0ce72d8b53b13d6aead5be35cff41d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Han=20Seung=20Min=20-=20=ED=95=9C=EC=8A=B9=EB=AF=BC?= Date: Sun, 13 Mar 2022 14:24:14 +0530 Subject: [PATCH] Refactor loop code to allow usage of asyncio.run --- discord/client.py | 114 ++++++------------------------ discord/ext/commands/cooldowns.py | 2 +- discord/ext/tasks/__init__.py | 24 +------ discord/http.py | 12 ++-- discord/player.py | 8 +-- discord/shard.py | 8 +-- discord/state.py | 2 +- discord/ui/view.py | 6 +- discord/voice_client.py | 4 +- 9 files changed, 44 insertions(+), 136 deletions(-) diff --git a/discord/client.py b/discord/client.py index 8d8151469..daac8a1ad 100644 --- a/discord/client.py +++ b/discord/client.py @@ -27,7 +27,6 @@ from __future__ import annotations import asyncio import datetime import logging -import signal import sys import traceback from typing import ( @@ -97,41 +96,6 @@ Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) _log = logging.getLogger(__name__) -def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: - tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} - - if not tasks: - return - - _log.info('Cleaning up after %d tasks.', len(tasks)) - for task in tasks: - task.cancel() - - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - _log.info('All tasks finished cancelling.') - - for task in tasks: - if task.cancelled(): - continue - if task.exception() is not None: - loop.call_exception_handler( - { - 'message': 'Unhandled exception during Client.run shutdown.', - 'exception': task.exception(), - 'task': task, - } - ) - - -def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: - try: - _cancel_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - _log.info('Closing the event loop.') - loop.close() - - class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -146,12 +110,6 @@ class Client: .. versionchanged:: 1.3 Allow disabling the message cache and change the default size to ``1000``. - loop: Optional[:class:`asyncio.AbstractEventLoop`] - The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations. - Defaults to ``None``, in which case the default event loop is used via - :func:`asyncio.get_event_loop()`. - connector: Optional[:class:`aiohttp.BaseConnector`] - The connector to use for connection pooling. proxy: Optional[:class:`str`] Proxy URL. proxy_auth: Optional[:class:`aiohttp.BasicAuth`] @@ -220,30 +178,23 @@ class Client: ----------- ws The websocket gateway the client is currently connected to. Could be ``None``. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the client uses for asynchronous operations. """ def __init__( self, - *, - loop: Optional[asyncio.AbstractEventLoop] = None, **options: Any, ): + self.loop: asyncio.AbstractEventLoop = MISSING # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: Optional[int] = options.get('shard_id') self.shard_count: Optional[int] = options.get('shard_count') - connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None) proxy: Optional[str] = options.pop('proxy', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) unsync_clock: bool = options.pop('assume_unsync_clock', True) - self.http: HTTPClient = HTTPClient( - connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop - ) + self.http: HTTPClient = HTTPClient(self.loop, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock) self._handlers: Dict[str, Callable] = { 'ready': self._handle_ready, @@ -399,7 +350,7 @@ class Client: ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task - return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') + return self.loop.create_task(wrapped, name=f'discord.py: {event_name}') def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: _log.debug('Dispatching event %s', event) @@ -623,6 +574,7 @@ class Client: await self.http.close() self._ready.clear() + self.loop = MISSING def clear(self) -> None: """Clears the internal state of the bot. @@ -646,8 +598,15 @@ class Client: TypeError An unexpected keyword argument was received. """ - await self.login(token) - await self.connect(reconnect=reconnect) + self.loop = asyncio.get_running_loop() + self.http.loop = self.loop + self._connection.loop = self.loop + try: + await self.login(token) + await self.connect(reconnect=reconnect) + finally: + if not self.is_closed(): + await self.close() async def setup_hook(self) -> None: """|coro| @@ -676,12 +635,9 @@ class Client: Roughly Equivalent to: :: try: - loop.run_until_complete(start(*args, **kwargs)) + asyncio.run(self.start(*args, **kwargs)) except KeyboardInterrupt: - loop.run_until_complete(close()) - # cancel all tasks lingering - finally: - loop.close() + return .. warning:: @@ -689,41 +645,13 @@ class Client: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - loop = self.loop - try: - loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) - loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) - except NotImplementedError: - pass - - async def runner(): - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() - - def stop_loop_on_completion(f): - loop.stop() - - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) - try: - loop.run_forever() + asyncio.run(self.start(*args, **kwargs)) except KeyboardInterrupt: - _log.info('Received signal to terminate bot and event loop.') - finally: - future.remove_done_callback(stop_loop_on_completion) - _log.info('Cleaning up tasks.') - _cleanup_loop(loop) - - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + # nothing to do here + # `asyncio.run` handles the loop cleanup + # and `self.start` closes all sockets and the HTTPClient instance. + return # properties @@ -1324,7 +1252,7 @@ class Client: """ code = utils.resolve_template(code) data = await self.http.get_template(code) - return Template(data=data, state=self._connection) # type: ignore + return Template(data=data, state=self._connection) async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild: """|coro| diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index a66478b80..e188712b3 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -297,7 +297,7 @@ class _Semaphore: def __init__(self, number: int) -> None: self.value: int = number - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() self._waiters: Deque[asyncio.Future] = deque() def __repr__(self) -> str: diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index d90474370..e3254cebd 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -101,11 +101,9 @@ class Loop(Generic[LF]): time: Union[datetime.time, Sequence[datetime.time]], count: Optional[int], reconnect: bool, - loop: asyncio.AbstractEventLoop, ) -> None: self.coro: LF = coro self.reconnect: bool = reconnect - self.loop: asyncio.AbstractEventLoop = loop self.count: Optional[int] = count self._current_loop = 0 self._handle: Optional[SleepHandle] = None @@ -147,7 +145,7 @@ class Loop(Generic[LF]): await coro(*args, **kwargs) def _try_sleep_until(self, dt: datetime.datetime): - self._handle = SleepHandle(dt=dt, loop=self.loop) + self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop()) return self._handle.wait() async def _loop(self, *args: Any, **kwargs: Any) -> None: @@ -219,7 +217,6 @@ class Loop(Generic[LF]): time=self._time, count=self.count, reconnect=self.reconnect, - loop=self.loop, ) copy._injected = obj copy._before_loop = self._before_loop @@ -332,10 +329,7 @@ class Loop(Generic[LF]): if self._injected is not None: args = (self._injected, *args) - if self.loop is MISSING: - self.loop = asyncio.get_event_loop() - - self._task = self.loop.create_task(self._loop(*args, **kwargs)) + self._task = asyncio.create_task(self._loop(*args, **kwargs)) return self._task def stop(self) -> None: @@ -740,9 +734,6 @@ def loop( Whether to handle errors and restart the task using an exponential back-off algorithm similar to the one used in :meth:`discord.Client.connect`. - loop: :class:`asyncio.AbstractEventLoop` - The loop to use to register the task, if not given - defaults to :func:`asyncio.get_event_loop`. Raises -------- @@ -754,15 +745,6 @@ def loop( """ def decorator(func: LF) -> Loop[LF]: - return Loop[LF]( - func, - seconds=seconds, - minutes=minutes, - hours=hours, - count=count, - time=time, - reconnect=reconnect, - loop=loop, - ) + return Loop[LF](func, seconds=seconds, minutes=minutes, hours=hours, count=count, time=time, reconnect=reconnect) return decorator diff --git a/discord/http.py b/discord/http.py index 46c23a806..b6a83def2 100644 --- a/discord/http.py +++ b/discord/http.py @@ -332,15 +332,15 @@ class HTTPClient: def __init__( self, + loop: asyncio.AbstractEventLoop, connector: Optional[aiohttp.BaseConnector] = None, *, proxy: Optional[str] = None, proxy_auth: Optional[aiohttp.BasicAuth] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, unsync_clock: bool = True, ) -> None: - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop - self.connector: aiohttp.BaseConnector = connector or aiohttp.TCPConnector(limit=0) + self.loop: asyncio.AbstractEventLoop = loop + self.connector: aiohttp.BaseConnector = connector or MISSING self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() self._global_over: asyncio.Event = asyncio.Event() @@ -544,7 +544,11 @@ class HTTPClient: async def static_login(self, token: str) -> user.User: # Necessary to get aiohttp to stop complaining about session creation - self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) + if self.connector is MISSING: + self.connector = aiohttp.TCPConnector(loop=self.loop, limit=0) + self.__session = aiohttp.ClientSession( + connector=self.connector, ws_response_class=DiscordClientWebSocketResponse, loop=self.loop + ) old_token = self.token self.token = token diff --git a/discord/player.py b/discord/player.py index 6337ed7dc..cf215756b 100644 --- a/discord/player.py +++ b/discord/player.py @@ -521,9 +521,9 @@ class FFmpegOpusAudio(FFmpegAudio): raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'") codec = bitrate = None - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: - codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore + codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) except Exception: if not fallback: _log.exception("Probe '%s' using '%s' failed", method, executable) @@ -531,7 +531,7 @@ class FFmpegOpusAudio(FFmpegAudio): _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) try: - codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore + codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) except Exception: _log.exception("Fallback probe using '%s' failed", executable) else: @@ -744,6 +744,6 @@ class AudioPlayer(threading.Thread): def _speak(self, speaking: SpeakingState) -> None: try: - asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop) + asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop) except Exception as e: _log.info("Speaking call in player failed: %s", e) diff --git a/discord/shard.py b/discord/shard.py index aa37a2c54..4ba521965 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -95,7 +95,6 @@ class Shard: self._client: Client = client self._dispatch: Callable[..., None] = client.dispatch self._queue_put: Callable[[EventItem], None] = queue_put - self.loop: asyncio.AbstractEventLoop = self._client.loop self._disconnect: bool = False self._reconnect = client._reconnect self._backoff: ExponentialBackoff = ExponentialBackoff() @@ -115,7 +114,7 @@ class Shard: return self.ws.shard_id # type: ignore def launch(self) -> None: - self._task = self.loop.create_task(self.worker()) + self._task = self._client.loop.create_task(self.worker()) def _cancel_task(self) -> None: if self._task is not None and not self._task.done(): @@ -318,10 +317,10 @@ class AutoShardedClient(Client): if TYPE_CHECKING: _connection: AutoShardedConnectionState - def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.pop('shard_id', None) self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) - super().__init__(*args, loop=loop, **kwargs) + super().__init__(*args, **kwargs) if self.shard_ids is not None: if self.shard_count is None: @@ -348,7 +347,6 @@ class AutoShardedClient(Client): handlers=self._handlers, hooks=self._hooks, http=self.http, - loop=self.loop, **options, ) diff --git a/discord/state.py b/discord/state.py index 29f09141b..3053df980 100644 --- a/discord/state.py +++ b/discord/state.py @@ -569,7 +569,7 @@ class ConnectionState: else: self.application_id = utils._get_as_snowflake(application, 'id') # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore + self.application_flags = ApplicationFlags._from_value(application['flags']) for guild_data in data['guilds']: self._add_guild_from_data(guild_data) # type: ignore diff --git a/discord/ui/view.py b/discord/ui/view.py index 979eff6c1..cbef4f7ef 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -177,12 +177,11 @@ class View: self.timeout = timeout self.children: List[Item] = self._init_children() self.__weights = _ViewWeights(self.children) - loop = asyncio.get_running_loop() self.id: str = os.urandom(16).hex() self.__cancel_callback: Optional[Callable[[View], None]] = None self.__timeout_expiry: Optional[float] = None self.__timeout_task: Optional[asyncio.Task[None]] = None - self.__stopped: asyncio.Future[bool] = loop.create_future() + self.__stopped: asyncio.Future[bool] = asyncio.get_running_loop().create_future() def __repr__(self) -> str: return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>' @@ -379,12 +378,11 @@ class View: def _start_listening_from_store(self, store: ViewStore) -> None: self.__cancel_callback = partial(store.remove_view) if self.timeout: - loop = asyncio.get_running_loop() if self.__timeout_task is not None: self.__timeout_task.cancel() self.__timeout_expiry = time.monotonic() + self.timeout - self.__timeout_task = loop.create_task(self.__timeout_task_impl()) + self.__timeout_task = asyncio.create_task(self.__timeout_task_impl()) def _dispatch_timeout(self): if self.__stopped.done(): diff --git a/discord/voice_client.py b/discord/voice_client.py index 37a268ba4..208bb78d3 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -222,8 +222,6 @@ class VoiceClient(VoiceProtocol): The endpoint we are connecting to. channel: Union[:class:`VoiceChannel`, :class:`StageChannel`] The voice channel connected to. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the voice client is running on. """ channel: VocalGuildChannel @@ -405,7 +403,7 @@ class VoiceClient(VoiceProtocol): raise if self._runner is MISSING: - self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) + self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect)) async def potential_reconnect(self) -> bool: # Attempt to stop the player thread from playing early