diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 91e49c3a7..dab4b5ea8 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -35,6 +35,9 @@ class Loop: websockets.WebSocketProtocolError, ) + self._before_loop = None + self._after_loop = None + if self.count is not None and self.count <= 0: raise ValueError('count must be greater than 0 or None.') @@ -47,25 +50,42 @@ class Loop: raise ValueError('Total number of seconds cannot be less than zero.') if not inspect.iscoroutinefunction(self.coro): - raise TypeError('Expected coroutine function, not {0!r}.'.format(type(self.coro))) + raise TypeError('Expected coroutine function, not {0.__name__!r}.'.format(type(self.coro))) - async def _loop(self, *args, **kwargs): - backoff = ExponentialBackoff() - while True: - try: - await self.coro(*args, **kwargs) - except asyncio.CancelledError: - return - except self._valid_exception as exc: - if not self.reconnect: - raise - await asyncio.sleep(backoff.delay()) + async def _call_loop_function(self, name): + coro = getattr(self, '_' + name) + if coro is None: + return + + if inspect.iscoroutinefunction(coro): + if self._injected is not None: + await coro(self._injected) else: - self._current_loop += 1 - if self._current_loop == self.count: - return + await coro() + else: + await coro - await asyncio.sleep(self._sleep) + async def _loop(self, *args, **kwargs): + backoff = ExponentialBackoff() + await self._call_loop_function('before_loop') + try: + while True: + try: + await self.coro(*args, **kwargs) + except asyncio.CancelledError: + break + except self._valid_exception as exc: + if not self.reconnect: + raise + await asyncio.sleep(backoff.delay()) + else: + self._current_loop += 1 + if self._current_loop == self.count: + break + + await asyncio.sleep(self._sleep) + finally: + await self._call_loop_function('after_loop') def __get__(self, obj, objtype): if obj is None: @@ -171,6 +191,49 @@ class Loop: """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" return self._task + def before_loop(self, coro): + """A function that also acts as a decorator to register a coroutine to be + called before the loop starts running. This is useful if you want to wait + for some bot state before the loop starts, + such as :meth:`discord.Client.wait_until_ready`. + + Parameters + ------------ + coro: :term:`py:awaitable` + The coroutine to register before the loop runs. + + Raises + ------- + TypeError + The function was not a coroutine. + """ + + if not (inspect.iscoroutinefunction(coro) or inspect.isawaitable(coro)): + raise TypeError('Expected coroutine or awaitable, received {0.__name__!r}.'.format(type(coro))) + + self._before_loop = coro + + + def after_loop(self, coro): + """A function that also acts as a decorator to register a coroutine to be + called after the loop finished running. + + Parameters + ------------ + coro: :term:`py:awaitable` + The coroutine to register after the loop finishes. + + Raises + ------- + TypeError + The function was not a coroutine. + """ + + if not (inspect.iscoroutinefunction(coro) or inspect.isawaitable(coro)): + raise TypeError('Expected coroutine or awaitable, received {0.__name__!r}.'.format(type(coro))) + + self._after_loop = coro + def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None): """A decorator that schedules a task in the background for you with optional reconnect logic. diff --git a/docs/ext/tasks/index.rst b/docs/ext/tasks/index.rst index 94d1320da..93e7b3f8f 100644 --- a/docs/ext/tasks/index.rst +++ b/docs/ext/tasks/index.rst @@ -66,14 +66,56 @@ Looping a certain amount of times before exiting: async def slow_count(): print(slow_count.current_loop) + @slow_count.after_loop + async def after_slow_count(): + print('done!') + slow_count.start() -Doing something after a task finishes is as simple as using :meth:`asyncio.Task.add_done_callback`: +Waiting until the bot is ready before the loop starts: .. code-block:: python3 - afterwards = lambda f: print('done!') - slow_count.get_task().add_done_callback(afterwards) + from discord.ext import tasks, commands + + class MyCog(commands.Cog): + def __init__(self, bot): + self.index = 0 + self.printer.before_loop(bot.wait_until_ready()) + self.printer.start() + + def cog_unload(self): + self.printer.cancel() + + @tasks.loop(seconds=5.0) + async def printer(self): + print(self.index) + self.index += 1 + +:meth:`~.tasks.Loop.before_loop` can be used as a decorator as well: + +.. code-block:: python3 + + from discord.ext import tasks, commands + + class MyCog(commands.Cog): + def __init__(self, bot): + self.index = 0 + self.bot = bot + self.printer.start() + + def cog_unload(self): + self.printer.cancel() + + @tasks.loop(seconds=5.0) + async def printer(self): + print(self.index) + self.index += 1 + + @printer.before_loop + async def before_printer(self): + print('waiting...') + await self.bot.wait_until_ready() API Reference ---------------