3 changed files with 292 additions and 0 deletions
@ -0,0 +1,209 @@ |
|||
import asyncio |
|||
import aiohttp |
|||
import websockets |
|||
import discord |
|||
import inspect |
|||
|
|||
from discord.backoff import ExponentialBackoff |
|||
|
|||
MAX_ASYNCIO_SECONDS = 3456000 |
|||
|
|||
class Loop: |
|||
"""A background task helper that abstracts the loop and reconnection logic for you. |
|||
|
|||
The main interface to create this is through :func:`loop`. |
|||
""" |
|||
def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop): |
|||
self.coro = coro |
|||
self.seconds = seconds |
|||
self.hours = hours |
|||
self.minutes = minutes |
|||
self.reconnect = reconnect |
|||
self.loop = loop or asyncio.get_event_loop() |
|||
self.count = count |
|||
self._current_loop = 0 |
|||
self._task = None |
|||
self._injected = None |
|||
self._valid_exception = ( |
|||
OSError, |
|||
discord.HTTPException, |
|||
discord.GatewayNotFound, |
|||
discord.ConnectionClosed, |
|||
aiohttp.ClientError, |
|||
asyncio.TimeoutError, |
|||
websockets.InvalidHandshake, |
|||
websockets.WebSocketProtocolError, |
|||
) |
|||
|
|||
if self.count is not None and self.count <= 0: |
|||
raise ValueError('count must be greater than 0 or None.') |
|||
|
|||
self._sleep = sleep = self.seconds + (self.minutes * 60.0) + (self.hours * 3600.0) |
|||
if sleep >= MAX_ASYNCIO_SECONDS: |
|||
raise ValueError('Total time exceeds asyncio imposed limit of {0} seconds.'.format(MAX_ASYNCIO_SECONDS)) |
|||
|
|||
if not inspect.iscoroutinefunction(self.coro): |
|||
raise TypeError('Expected coroutine function, not {0!r}.'.format(type(self.coro))) |
|||
|
|||
async def _loop(self, *args, **kwargs): |
|||
backoff = ExponentialBackoff() |
|||
while True: |
|||
try: |
|||
await self.coro(*args, **kwargs) |
|||
except asyncio.CancelledError: |
|||
return |
|||
except self._valid_exception as exc: |
|||
if not self.reconnect: |
|||
raise |
|||
await asyncio.sleep(backoff.delay()) |
|||
else: |
|||
self._current_loop += 1 |
|||
if self._current_loop == self.count: |
|||
return |
|||
|
|||
await asyncio.sleep(self._sleep) |
|||
|
|||
def __get__(self, obj, objtype): |
|||
if obj is None: |
|||
return self |
|||
self._injected = obj |
|||
return self |
|||
|
|||
@property |
|||
def current_loop(self): |
|||
""":class:`int`: The current iteration of the loop.""" |
|||
return self._current_loop |
|||
|
|||
|
|||
def run(self, *args, **kwargs): |
|||
r"""Runs the internal task in the event loop. |
|||
|
|||
Parameters |
|||
------------ |
|||
\*args |
|||
The arguments to to use. |
|||
\*\*kwargs |
|||
The keyword arguments to use. |
|||
|
|||
Raises |
|||
-------- |
|||
RuntimeError |
|||
A task has already been launched. |
|||
|
|||
Returns |
|||
--------- |
|||
:class:`asyncio.Task` |
|||
The task that has been registered. |
|||
""" |
|||
|
|||
if self._task is not None: |
|||
raise RuntimeError('Task is already launched.') |
|||
|
|||
if self._injected is not None: |
|||
args = (self._injected, *args) |
|||
|
|||
self._task = self.loop.create_task(self._loop(*args, **kwargs)) |
|||
return self._task |
|||
|
|||
def cancel(self): |
|||
"""Cancels the internal task, if any are running.""" |
|||
if self._task: |
|||
self._task.cancel() |
|||
|
|||
def add_exception_type(self, exc): |
|||
r"""Adds an exception type to be handled during the reconnect logic. |
|||
|
|||
By default the exception types handled are those handled by |
|||
:meth:`discord.Client.connect`\, which includes a lot of internet disconnection |
|||
errors. |
|||
|
|||
This function is useful if you're interacting with a 3rd party library that |
|||
raises its own set of exceptions. |
|||
|
|||
Parameters |
|||
------------ |
|||
exc: Type[:class:`BaseException`] |
|||
The exception class to handle. |
|||
|
|||
Raises |
|||
-------- |
|||
TypeError |
|||
The exception passed is either not a class or not inherited from :class:`BaseException`. |
|||
""" |
|||
|
|||
if not inspect.isclass(exc): |
|||
raise TypeError('{0!r} must be a class.'.format(exc)) |
|||
if not issubclass(exc, BaseException): |
|||
raise TypeError('{0!r} must inherit from BaseException.'.format(exc)) |
|||
|
|||
self._valid_exception = tuple(*self._valid_exception, exc) |
|||
|
|||
def clear_exception_types(self): |
|||
"""Removes all exception types that are handled. |
|||
|
|||
.. note:: |
|||
|
|||
This operation obviously cannot be undone! |
|||
""" |
|||
self._valid_exception = tuple() |
|||
|
|||
def remove_exception_type(self, exc): |
|||
"""Removes an exception type from being handled during the reconnect logic. |
|||
|
|||
Parameters |
|||
------------ |
|||
exc: Type[:class:`BaseException`] |
|||
The exception class to handle. |
|||
|
|||
Returns |
|||
--------- |
|||
:class:`bool` |
|||
Whether it was successfully removed. |
|||
""" |
|||
old_length = len(self._valid_exception) |
|||
self._valid_exception = tuple(x for x in self._valid_exception if x is not exc) |
|||
return len(self._valid_exception) != old_length |
|||
|
|||
def get_task(self): |
|||
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" |
|||
return self._task |
|||
|
|||
def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None): |
|||
"""A decorator that schedules a task in the background for you with |
|||
optional reconnect logic. |
|||
|
|||
Parameters |
|||
------------ |
|||
seconds: :class:`float` |
|||
The number of seconds between every iteration. |
|||
minutes: :class:`float` |
|||
The number of minutes between every iteration. |
|||
hours: :class:`float` |
|||
The number of hours between every iteration. |
|||
count: Optional[:class:`int`] |
|||
The number of loops to do, ``None`` if it should be an |
|||
infinite loop. |
|||
reconnect: :class:`bool` |
|||
Whether to handle errors and restart the task |
|||
using an exponential back-off algorithm similar to the |
|||
one used in :meth:`discord.Client.connect`. |
|||
loop: :class:`asyncio.AbstractEventLoop` |
|||
The loop to use to register the task, if not given |
|||
defaults to :func:`asyncio.get_event_loop`. |
|||
|
|||
Raises |
|||
-------- |
|||
ValueError |
|||
An invalid value was given. |
|||
TypeError |
|||
The function was not a coroutine. |
|||
|
|||
Returns |
|||
--------- |
|||
:class:`Loop` |
|||
The loop helper that handles the background task. |
|||
""" |
|||
def decorator(func): |
|||
return Loop(func, seconds=seconds, minutes=minutes, hours=hours, |
|||
count=count, reconnect=reconnect, loop=loop) |
|||
return decorator |
@ -0,0 +1,82 @@ |
|||
``discord.ext.tasks`` -- asyncio.Task helpers |
|||
==================================================== |
|||
|
|||
One of the most common operations when making a bot is having a loop run in the background at a specified interval. This pattern is very common but has a lot of things you need to look out for: |
|||
|
|||
- How do I handle :exc:`asyncio.CancelledError`? |
|||
- What do I do if the internet goes out? |
|||
- What is the maximum number of seconds I can sleep anyway? |
|||
|
|||
The goal of this discord.py extension is to abstract all these worries away from you. |
|||
|
|||
Recipes |
|||
--------- |
|||
|
|||
A simple background task in a :class:`~discord.ext.commands.Cog`: |
|||
|
|||
.. code-block:: python3 |
|||
|
|||
from discord.ext import tasks, commands |
|||
|
|||
class MyCog(commands.Cog): |
|||
def __init__(self): |
|||
self.index = 0 |
|||
self.printer.run() |
|||
|
|||
def cog_unload(self): |
|||
self.printer.cancel() |
|||
|
|||
@tasks.loop(seconds=5.0) |
|||
async def printer(self): |
|||
print(self.index) |
|||
self.index += 1 |
|||
|
|||
Adding an exception to handle during reconnect: |
|||
|
|||
.. code-block:: python3 |
|||
|
|||
import asyncpg |
|||
from discord.ext import tasks, commands |
|||
|
|||
class MyCog(commands.Cog): |
|||
def __init__(self, bot): |
|||
self.bot = bot |
|||
self.data = [] |
|||
self.batch_update.add_exception_type(asyncpg.PostgresConnectionError) |
|||
self.batch_update.run() |
|||
|
|||
def cog_unload(self): |
|||
self.batch_update.cancel() |
|||
|
|||
@tasks.loop(minutes=5.0) |
|||
async def batch_update(self): |
|||
async with self.bot.pool.acquire() as con: |
|||
# batch update here... |
|||
pass |
|||
|
|||
Looping a certain amount of times before exiting: |
|||
|
|||
.. code-block:: python3 |
|||
|
|||
from discord.ext import tasks |
|||
|
|||
@tasks.loop(seconds=5.0, count=5) |
|||
async def slow_count(): |
|||
print(slow_count.current_loop) |
|||
|
|||
slow_count.run() |
|||
|
|||
Doing something after a task finishes is as simple as using :meth:`asyncio.Task.add_done_callback`: |
|||
|
|||
.. code-block:: python3 |
|||
|
|||
afterwards = lambda f: print('done!') |
|||
slow_count.get_task().add_done_callback(afterwards) |
|||
|
|||
API Reference |
|||
--------------- |
|||
|
|||
.. autoclass:: discord.ext.tasks.Loop() |
|||
:members: |
|||
|
|||
.. autofunction:: discord.ext.tasks.loop |
Loading…
Reference in new issue