diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 89070594a..bba436d45 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -5,6 +5,8 @@ import websockets import discord import inspect import logging +import sys +import traceback from discord.backoff import ExponentialBackoff @@ -50,15 +52,15 @@ class Loop: if not inspect.iscoroutinefunction(self.coro): raise TypeError('Expected coroutine function, not {0.__name__!r}.'.format(type(self.coro))) - async def _call_loop_function(self, name): + async def _call_loop_function(self, name, *args, **kwargs): coro = getattr(self, '_' + name) if coro is None: return if self._injected is not None: - await coro(self._injected) + await coro(self._injected, *args, **kwargs) else: - await coro() + await coro(*args, **kwargs) async def _loop(self, *args, **kwargs): backoff = ExponentialBackoff() @@ -89,10 +91,10 @@ class Loop: except asyncio.CancelledError: self._is_being_cancelled = True raise - except Exception: + except Exception as exc: self._has_failed = True - log.exception('Internal background task failed.') - raise + await self._call_loop_function('error', exc) + raise exc finally: await self._call_loop_function('after_loop') self._is_being_cancelled = False @@ -283,6 +285,10 @@ class Loop: """ return not bool(self._task.done()) if self._task else False + async def _error(self, exception): + print('Unhandled exception in internal background task {0.__name__!r}.'.format(self.coro), file=sys.stderr) + traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) + def before_loop(self, coro): """A decorator that registers a coroutine to be called before the loop starts running. @@ -336,6 +342,32 @@ class Loop: self._after_loop = coro return coro + def error(self, coro): + """A decorator that registers a coroutine to be called if the task encounters an unhandled exception. + + The coroutine must take only one argument the exception raised (except ``self`` in a class context). + + By default this prints to :data:`sys.stderr` however it could be + overridden to have a different implementation. + + .. versionadded:: 1.4 + + Parameters + ------------ + coro: :ref:`coroutine ` + The coroutine to register in the event of an unhandled exception. + + Raises + ------- + TypeError + The function was not a coroutine. + """ + if not inspect.iscoroutinefunction(coro): + raise TypeError('Expected coroutine function, received {0.__name__!r}.'.format(type(coro))) + + self._error = coro + return coro + def _get_next_sleep_time(self): return self._last_iteration + datetime.timedelta(seconds=self._sleep)