3 changed files with 292 additions and 0 deletions
@ -0,0 +1,209 @@ |
|||||
|
import asyncio |
||||
|
import aiohttp |
||||
|
import websockets |
||||
|
import discord |
||||
|
import inspect |
||||
|
|
||||
|
from discord.backoff import ExponentialBackoff |
||||
|
|
||||
|
MAX_ASYNCIO_SECONDS = 3456000 |
||||
|
|
||||
|
class Loop: |
||||
|
"""A background task helper that abstracts the loop and reconnection logic for you. |
||||
|
|
||||
|
The main interface to create this is through :func:`loop`. |
||||
|
""" |
||||
|
def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop): |
||||
|
self.coro = coro |
||||
|
self.seconds = seconds |
||||
|
self.hours = hours |
||||
|
self.minutes = minutes |
||||
|
self.reconnect = reconnect |
||||
|
self.loop = loop or asyncio.get_event_loop() |
||||
|
self.count = count |
||||
|
self._current_loop = 0 |
||||
|
self._task = None |
||||
|
self._injected = None |
||||
|
self._valid_exception = ( |
||||
|
OSError, |
||||
|
discord.HTTPException, |
||||
|
discord.GatewayNotFound, |
||||
|
discord.ConnectionClosed, |
||||
|
aiohttp.ClientError, |
||||
|
asyncio.TimeoutError, |
||||
|
websockets.InvalidHandshake, |
||||
|
websockets.WebSocketProtocolError, |
||||
|
) |
||||
|
|
||||
|
if self.count is not None and self.count <= 0: |
||||
|
raise ValueError('count must be greater than 0 or None.') |
||||
|
|
||||
|
self._sleep = sleep = self.seconds + (self.minutes * 60.0) + (self.hours * 3600.0) |
||||
|
if sleep >= MAX_ASYNCIO_SECONDS: |
||||
|
raise ValueError('Total time exceeds asyncio imposed limit of {0} seconds.'.format(MAX_ASYNCIO_SECONDS)) |
||||
|
|
||||
|
if not inspect.iscoroutinefunction(self.coro): |
||||
|
raise TypeError('Expected coroutine function, not {0!r}.'.format(type(self.coro))) |
||||
|
|
||||
|
async def _loop(self, *args, **kwargs): |
||||
|
backoff = ExponentialBackoff() |
||||
|
while True: |
||||
|
try: |
||||
|
await self.coro(*args, **kwargs) |
||||
|
except asyncio.CancelledError: |
||||
|
return |
||||
|
except self._valid_exception as exc: |
||||
|
if not self.reconnect: |
||||
|
raise |
||||
|
await asyncio.sleep(backoff.delay()) |
||||
|
else: |
||||
|
self._current_loop += 1 |
||||
|
if self._current_loop == self.count: |
||||
|
return |
||||
|
|
||||
|
await asyncio.sleep(self._sleep) |
||||
|
|
||||
|
def __get__(self, obj, objtype): |
||||
|
if obj is None: |
||||
|
return self |
||||
|
self._injected = obj |
||||
|
return self |
||||
|
|
||||
|
@property |
||||
|
def current_loop(self): |
||||
|
""":class:`int`: The current iteration of the loop.""" |
||||
|
return self._current_loop |
||||
|
|
||||
|
|
||||
|
def run(self, *args, **kwargs): |
||||
|
r"""Runs the internal task in the event loop. |
||||
|
|
||||
|
Parameters |
||||
|
------------ |
||||
|
\*args |
||||
|
The arguments to to use. |
||||
|
\*\*kwargs |
||||
|
The keyword arguments to use. |
||||
|
|
||||
|
Raises |
||||
|
-------- |
||||
|
RuntimeError |
||||
|
A task has already been launched. |
||||
|
|
||||
|
Returns |
||||
|
--------- |
||||
|
:class:`asyncio.Task` |
||||
|
The task that has been registered. |
||||
|
""" |
||||
|
|
||||
|
if self._task is not None: |
||||
|
raise RuntimeError('Task is already launched.') |
||||
|
|
||||
|
if self._injected is not None: |
||||
|
args = (self._injected, *args) |
||||
|
|
||||
|
self._task = self.loop.create_task(self._loop(*args, **kwargs)) |
||||
|
return self._task |
||||
|
|
||||
|
def cancel(self): |
||||
|
"""Cancels the internal task, if any are running.""" |
||||
|
if self._task: |
||||
|
self._task.cancel() |
||||
|
|
||||
|
def add_exception_type(self, exc): |
||||
|
r"""Adds an exception type to be handled during the reconnect logic. |
||||
|
|
||||
|
By default the exception types handled are those handled by |
||||
|
:meth:`discord.Client.connect`\, which includes a lot of internet disconnection |
||||
|
errors. |
||||
|
|
||||
|
This function is useful if you're interacting with a 3rd party library that |
||||
|
raises its own set of exceptions. |
||||
|
|
||||
|
Parameters |
||||
|
------------ |
||||
|
exc: Type[:class:`BaseException`] |
||||
|
The exception class to handle. |
||||
|
|
||||
|
Raises |
||||
|
-------- |
||||
|
TypeError |
||||
|
The exception passed is either not a class or not inherited from :class:`BaseException`. |
||||
|
""" |
||||
|
|
||||
|
if not inspect.isclass(exc): |
||||
|
raise TypeError('{0!r} must be a class.'.format(exc)) |
||||
|
if not issubclass(exc, BaseException): |
||||
|
raise TypeError('{0!r} must inherit from BaseException.'.format(exc)) |
||||
|
|
||||
|
self._valid_exception = tuple(*self._valid_exception, exc) |
||||
|
|
||||
|
def clear_exception_types(self): |
||||
|
"""Removes all exception types that are handled. |
||||
|
|
||||
|
.. note:: |
||||
|
|
||||
|
This operation obviously cannot be undone! |
||||
|
""" |
||||
|
self._valid_exception = tuple() |
||||
|
|
||||
|
def remove_exception_type(self, exc): |
||||
|
"""Removes an exception type from being handled during the reconnect logic. |
||||
|
|
||||
|
Parameters |
||||
|
------------ |
||||
|
exc: Type[:class:`BaseException`] |
||||
|
The exception class to handle. |
||||
|
|
||||
|
Returns |
||||
|
--------- |
||||
|
:class:`bool` |
||||
|
Whether it was successfully removed. |
||||
|
""" |
||||
|
old_length = len(self._valid_exception) |
||||
|
self._valid_exception = tuple(x for x in self._valid_exception if x is not exc) |
||||
|
return len(self._valid_exception) != old_length |
||||
|
|
||||
|
def get_task(self): |
||||
|
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" |
||||
|
return self._task |
||||
|
|
||||
|
def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None): |
||||
|
"""A decorator that schedules a task in the background for you with |
||||
|
optional reconnect logic. |
||||
|
|
||||
|
Parameters |
||||
|
------------ |
||||
|
seconds: :class:`float` |
||||
|
The number of seconds between every iteration. |
||||
|
minutes: :class:`float` |
||||
|
The number of minutes between every iteration. |
||||
|
hours: :class:`float` |
||||
|
The number of hours between every iteration. |
||||
|
count: Optional[:class:`int`] |
||||
|
The number of loops to do, ``None`` if it should be an |
||||
|
infinite loop. |
||||
|
reconnect: :class:`bool` |
||||
|
Whether to handle errors and restart the task |
||||
|
using an exponential back-off algorithm similar to the |
||||
|
one used in :meth:`discord.Client.connect`. |
||||
|
loop: :class:`asyncio.AbstractEventLoop` |
||||
|
The loop to use to register the task, if not given |
||||
|
defaults to :func:`asyncio.get_event_loop`. |
||||
|
|
||||
|
Raises |
||||
|
-------- |
||||
|
ValueError |
||||
|
An invalid value was given. |
||||
|
TypeError |
||||
|
The function was not a coroutine. |
||||
|
|
||||
|
Returns |
||||
|
--------- |
||||
|
:class:`Loop` |
||||
|
The loop helper that handles the background task. |
||||
|
""" |
||||
|
def decorator(func): |
||||
|
return Loop(func, seconds=seconds, minutes=minutes, hours=hours, |
||||
|
count=count, reconnect=reconnect, loop=loop) |
||||
|
return decorator |
@ -0,0 +1,82 @@ |
|||||
|
``discord.ext.tasks`` -- asyncio.Task helpers |
||||
|
==================================================== |
||||
|
|
||||
|
One of the most common operations when making a bot is having a loop run in the background at a specified interval. This pattern is very common but has a lot of things you need to look out for: |
||||
|
|
||||
|
- How do I handle :exc:`asyncio.CancelledError`? |
||||
|
- What do I do if the internet goes out? |
||||
|
- What is the maximum number of seconds I can sleep anyway? |
||||
|
|
||||
|
The goal of this discord.py extension is to abstract all these worries away from you. |
||||
|
|
||||
|
Recipes |
||||
|
--------- |
||||
|
|
||||
|
A simple background task in a :class:`~discord.ext.commands.Cog`: |
||||
|
|
||||
|
.. code-block:: python3 |
||||
|
|
||||
|
from discord.ext import tasks, commands |
||||
|
|
||||
|
class MyCog(commands.Cog): |
||||
|
def __init__(self): |
||||
|
self.index = 0 |
||||
|
self.printer.run() |
||||
|
|
||||
|
def cog_unload(self): |
||||
|
self.printer.cancel() |
||||
|
|
||||
|
@tasks.loop(seconds=5.0) |
||||
|
async def printer(self): |
||||
|
print(self.index) |
||||
|
self.index += 1 |
||||
|
|
||||
|
Adding an exception to handle during reconnect: |
||||
|
|
||||
|
.. code-block:: python3 |
||||
|
|
||||
|
import asyncpg |
||||
|
from discord.ext import tasks, commands |
||||
|
|
||||
|
class MyCog(commands.Cog): |
||||
|
def __init__(self, bot): |
||||
|
self.bot = bot |
||||
|
self.data = [] |
||||
|
self.batch_update.add_exception_type(asyncpg.PostgresConnectionError) |
||||
|
self.batch_update.run() |
||||
|
|
||||
|
def cog_unload(self): |
||||
|
self.batch_update.cancel() |
||||
|
|
||||
|
@tasks.loop(minutes=5.0) |
||||
|
async def batch_update(self): |
||||
|
async with self.bot.pool.acquire() as con: |
||||
|
# batch update here... |
||||
|
pass |
||||
|
|
||||
|
Looping a certain amount of times before exiting: |
||||
|
|
||||
|
.. code-block:: python3 |
||||
|
|
||||
|
from discord.ext import tasks |
||||
|
|
||||
|
@tasks.loop(seconds=5.0, count=5) |
||||
|
async def slow_count(): |
||||
|
print(slow_count.current_loop) |
||||
|
|
||||
|
slow_count.run() |
||||
|
|
||||
|
Doing something after a task finishes is as simple as using :meth:`asyncio.Task.add_done_callback`: |
||||
|
|
||||
|
.. code-block:: python3 |
||||
|
|
||||
|
afterwards = lambda f: print('done!') |
||||
|
slow_count.get_task().add_done_callback(afterwards) |
||||
|
|
||||
|
API Reference |
||||
|
--------------- |
||||
|
|
||||
|
.. autoclass:: discord.ext.tasks.Loop() |
||||
|
:members: |
||||
|
|
||||
|
.. autofunction:: discord.ext.tasks.loop |
Loading…
Reference in new issue