diff --git a/discord/client.py b/discord/client.py index 3ad4f7ae0..0829f1547 100644 --- a/discord/client.py +++ b/discord/client.py @@ -55,6 +55,46 @@ from .appinfo import AppInfo log = logging.getLogger(__name__) +class _ProperCleanup(Exception): + pass + +def _raise_proper_cleanup(): + raise _ProperCleanup + +def _cancel_tasks(loop, tasks): + if not tasks: + return + + log.info('Cleaning up after %d tasks.', len(tasks)) + gathered = asyncio.gather(*tasks, loop=loop, return_exceptions=True) + gathered.cancel() + gathered.add_done_callback(lambda fut: loop.stop()) + + while not gathered.done(): + loop.run_forever() + + for task in tasks: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'Unhandled exception during Client.run shutdown.', + 'exception': task.exception(), + 'task': task + }) + +def _cleanup_loop(loop): + try: + task_retriever = asyncio.Task.all_tasks + except AttributeError: + # future proofing for 3.9 I guess + task_retriever = asyncio.all_tasks + + all_tasks = {t for t in task_retriever(loop=loop) if not t.done()} + _cancel_tasks(loop, all_tasks) + if sys.version_info >= (3, 6): + loop.run_until_complete(loop.shutdown_asyncgens()) + class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -463,43 +503,28 @@ class Client: def _do_cleanup(self): log.info('Cleaning up event loop.') loop = self.loop - if loop.is_closed(): - return # we're already cleaning up task = asyncio.ensure_future(self.close(), loop=loop) - def _silence_gathered(fut): + def stop_loop(fut): try: fut.result() - except Exception: + except asyncio.CancelledError: pass + except Exception as e: + loop.call_exception_handler({ + 'message': 'Unexpected exception during Client.close', + 'exception': e + }) finally: loop.stop() - def when_future_is_done(fut): - pending = asyncio.Task.all_tasks(loop=loop) - if pending: - log.info('Cleaning up after %s tasks', len(pending)) - gathered = asyncio.gather(*pending, loop=loop) - gathered.cancel() - gathered.add_done_callback(_silence_gathered) - else: - loop.stop() - - task.add_done_callback(when_future_is_done) - if not loop.is_running(): - loop.run_forever() - else: - # on Linux, we're still running because we got triggered via - # the signal handler rather than the natural KeyboardInterrupt - # Since that's the case, we're going to return control after - # registering the task for the event loop to handle later - return None - + task.add_done_callback(stop_loop) try: - return task.result() # suppress unused task warning - except Exception: - return None + loop.run_forever() + finally: + _cleanup_loop(loop) + loop.close() def run(self, *args, **kwargs): """A blocking call that abstracts away the `event loop`_ @@ -528,29 +553,15 @@ class Client: is_windows = sys.platform == 'win32' loop = self.loop if not is_windows: - loop.add_signal_handler(signal.SIGINT, self._do_cleanup) - loop.add_signal_handler(signal.SIGTERM, self._do_cleanup) - - task = asyncio.ensure_future(self.start(*args, **kwargs), loop=loop) - - def stop_loop_on_finish(fut): - loop.stop() - - task.add_done_callback(stop_loop_on_finish) + loop.add_signal_handler(signal.SIGINT, _raise_proper_cleanup) + loop.add_signal_handler(signal.SIGTERM, _raise_proper_cleanup) try: - loop.run_forever() - except KeyboardInterrupt: + loop.run_until_complete(self.start(*args, **kwargs)) + except (_ProperCleanup, KeyboardInterrupt): log.info('Received signal to terminate bot and event loop.') finally: - task.remove_done_callback(stop_loop_on_finish) - if is_windows: - self._do_cleanup() - - loop.close() - if task.cancelled() or not task.done(): - return None - return task.result() + self._do_cleanup() # properties