|
|
@ -22,8 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
|
|
|
DEALINGS IN THE SOFTWARE. |
|
|
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import datetime |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
Awaitable, |
|
|
|
Callable, |
|
|
|
Generic, |
|
|
|
List, |
|
|
|
Optional, |
|
|
|
Type, |
|
|
|
TypeVar, |
|
|
|
Union, |
|
|
|
cast, |
|
|
|
) |
|
|
|
|
|
|
|
import aiohttp |
|
|
|
import discord |
|
|
|
import inspect |
|
|
@ -33,6 +48,7 @@ import traceback |
|
|
|
|
|
|
|
from collections.abc import Sequence |
|
|
|
from discord.backoff import ExponentialBackoff |
|
|
|
from discord.utils import MISSING |
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
|
@ -40,41 +56,58 @@ __all__ = ( |
|
|
|
'loop', |
|
|
|
) |
|
|
|
|
|
|
|
T = TypeVar('T') |
|
|
|
_func = Callable[..., Awaitable[Any]] |
|
|
|
LF = TypeVar('LF', bound=_func) |
|
|
|
FT = TypeVar('FT', bound=_func) |
|
|
|
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) |
|
|
|
LT = TypeVar('LT', bound='Loop') |
|
|
|
|
|
|
|
|
|
|
|
class SleepHandle: |
|
|
|
__slots__ = ('future', 'loop', 'handle') |
|
|
|
|
|
|
|
def __init__(self, dt, *, loop): |
|
|
|
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
self.loop = loop |
|
|
|
self.future = future = loop.create_future() |
|
|
|
relative_delta = discord.utils.compute_timedelta(dt) |
|
|
|
self.handle = loop.call_later(relative_delta, future.set_result, True) |
|
|
|
|
|
|
|
def recalculate(self, dt): |
|
|
|
def recalculate(self, dt: datetime.datetime) -> None: |
|
|
|
self.handle.cancel() |
|
|
|
relative_delta = discord.utils.compute_timedelta(dt) |
|
|
|
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) |
|
|
|
|
|
|
|
def wait(self): |
|
|
|
def wait(self) -> asyncio.Future: |
|
|
|
return self.future |
|
|
|
|
|
|
|
def done(self): |
|
|
|
def done(self) -> bool: |
|
|
|
return self.future.done() |
|
|
|
|
|
|
|
def cancel(self): |
|
|
|
def cancel(self) -> None: |
|
|
|
self.handle.cancel() |
|
|
|
self.future.cancel() |
|
|
|
|
|
|
|
|
|
|
|
class Loop: |
|
|
|
class Loop(Generic[LF]): |
|
|
|
"""A background task helper that abstracts the loop and reconnection logic for you. |
|
|
|
|
|
|
|
The main interface to create this is through :func:`loop`. |
|
|
|
""" |
|
|
|
def __init__(self, coro, seconds, hours, minutes, time, count, reconnect, loop): |
|
|
|
self.coro = coro |
|
|
|
self.reconnect = reconnect |
|
|
|
self.loop = loop |
|
|
|
self.count = count |
|
|
|
def __init__(self, |
|
|
|
coro: LF, |
|
|
|
seconds: float, |
|
|
|
hours: float, |
|
|
|
minutes: float, |
|
|
|
time: Union[datetime.time, Sequence[datetime.time]], |
|
|
|
count: Optional[int], |
|
|
|
reconnect: bool, |
|
|
|
loop: Optional[asyncio.AbstractEventLoop], |
|
|
|
) -> None: |
|
|
|
self.coro: LF = coro |
|
|
|
self.reconnect: bool = reconnect |
|
|
|
self.loop: Optional[asyncio.AbstractEventLoop] = loop |
|
|
|
self.count: Optional[int] = count |
|
|
|
self._current_loop = 0 |
|
|
|
self._handle = None |
|
|
|
self._task = None |
|
|
@ -104,7 +137,7 @@ class Loop: |
|
|
|
if not inspect.iscoroutinefunction(self.coro): |
|
|
|
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') |
|
|
|
|
|
|
|
async def _call_loop_function(self, name, *args, **kwargs): |
|
|
|
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: |
|
|
|
coro = getattr(self, '_' + name) |
|
|
|
if coro is None: |
|
|
|
return |
|
|
@ -114,16 +147,16 @@ class Loop: |
|
|
|
else: |
|
|
|
await coro(*args, **kwargs) |
|
|
|
|
|
|
|
def _try_sleep_until(self, dt): |
|
|
|
self._handle = SleepHandle(dt=dt, loop=self.loop) |
|
|
|
|
|
|
|
def _try_sleep_until(self, dt: datetime.datetime): |
|
|
|
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore |
|
|
|
return self._handle.wait() |
|
|
|
|
|
|
|
async def _loop(self, *args, **kwargs): |
|
|
|
async def _loop(self, *args: Any, **kwargs: Any) -> None: |
|
|
|
backoff = ExponentialBackoff() |
|
|
|
await self._call_loop_function('before_loop') |
|
|
|
sleep_until = discord.utils.sleep_until |
|
|
|
self._last_iteration_failed = False |
|
|
|
if self._time is not None: |
|
|
|
if self._time is not MISSING: |
|
|
|
# the time index should be prepared every time the internal loop is started |
|
|
|
self._prepare_time_index() |
|
|
|
self._next_iteration = self._get_next_sleep_time() |
|
|
@ -174,7 +207,7 @@ class Loop: |
|
|
|
self._stop_next_iteration = False |
|
|
|
self._has_failed = False |
|
|
|
|
|
|
|
def __get__(self, obj, objtype): |
|
|
|
def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]: |
|
|
|
if obj is None: |
|
|
|
return self |
|
|
|
|
|
|
@ -183,8 +216,8 @@ class Loop: |
|
|
|
seconds=self._seconds, |
|
|
|
hours=self._hours, |
|
|
|
minutes=self._minutes, |
|
|
|
count=self.count, |
|
|
|
time=self._time, |
|
|
|
count=self.count, |
|
|
|
reconnect=self.reconnect, |
|
|
|
loop=self.loop, |
|
|
|
) |
|
|
@ -196,49 +229,52 @@ class Loop: |
|
|
|
return copy |
|
|
|
|
|
|
|
@property |
|
|
|
def seconds(self): |
|
|
|
def seconds(self) -> Optional[float]: |
|
|
|
"""Optional[:class:`float`]: Read-only value for the number of seconds |
|
|
|
between each iteration. ``None`` if an explicit ``time`` value was passed instead. |
|
|
|
|
|
|
|
.. versionadded:: 2.0 |
|
|
|
""" |
|
|
|
return self._seconds |
|
|
|
if self._seconds is not MISSING: |
|
|
|
return self._seconds |
|
|
|
|
|
|
|
@property |
|
|
|
def minutes(self): |
|
|
|
def minutes(self) -> Optional[float]: |
|
|
|
"""Optional[:class:`float`]: Read-only value for the number of minutes |
|
|
|
between each iteration. ``None`` if an explicit ``time`` value was passed instead. |
|
|
|
|
|
|
|
.. versionadded:: 2.0 |
|
|
|
""" |
|
|
|
return self._minutes |
|
|
|
if self._minutes is not MISSING: |
|
|
|
return self._minutes |
|
|
|
|
|
|
|
@property |
|
|
|
def hours(self): |
|
|
|
def hours(self) -> Optional[float]: |
|
|
|
"""Optional[:class:`float`]: Read-only value for the number of hours |
|
|
|
between each iteration. ``None`` if an explicit ``time`` value was passed instead. |
|
|
|
|
|
|
|
.. versionadded:: 2.0 |
|
|
|
""" |
|
|
|
return self._hours |
|
|
|
if self._hours is not MISSING: |
|
|
|
return self._hours |
|
|
|
|
|
|
|
@property |
|
|
|
def time(self): |
|
|
|
def time(self) -> Optional[List[datetime.time]]: |
|
|
|
"""Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at. |
|
|
|
``None`` if relative times were passed instead. |
|
|
|
|
|
|
|
.. versionadded:: 2.0 |
|
|
|
""" |
|
|
|
if self._time is not None: |
|
|
|
if self._time is not MISSING: |
|
|
|
return self._time.copy() |
|
|
|
|
|
|
|
@property |
|
|
|
def current_loop(self): |
|
|
|
def current_loop(self) -> int: |
|
|
|
""":class:`int`: The current iteration of the loop.""" |
|
|
|
return self._current_loop |
|
|
|
|
|
|
|
@property |
|
|
|
def next_iteration(self): |
|
|
|
def next_iteration(self) -> Optional[datetime.datetime]: |
|
|
|
"""Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur. |
|
|
|
|
|
|
|
.. versionadded:: 1.3 |
|
|
@ -249,7 +285,7 @@ class Loop: |
|
|
|
return None |
|
|
|
return self._next_iteration |
|
|
|
|
|
|
|
async def __call__(self, *args, **kwargs): |
|
|
|
async def __call__(self, *args: Any, **kwargs: Any) -> Any: |
|
|
|
r"""|coro| |
|
|
|
|
|
|
|
Calls the internal callback that the task holds. |
|
|
@ -269,7 +305,7 @@ class Loop: |
|
|
|
|
|
|
|
return await self.coro(*args, **kwargs) |
|
|
|
|
|
|
|
def start(self, *args, **kwargs): |
|
|
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: |
|
|
|
r"""Starts the internal task in the event loop. |
|
|
|
|
|
|
|
Parameters |
|
|
@ -302,7 +338,7 @@ class Loop: |
|
|
|
self._task = self.loop.create_task(self._loop(*args, **kwargs)) |
|
|
|
return self._task |
|
|
|
|
|
|
|
def stop(self): |
|
|
|
def stop(self) -> None: |
|
|
|
r"""Gracefully stops the task from running. |
|
|
|
|
|
|
|
Unlike :meth:`cancel`\, this allows the task to finish its |
|
|
@ -323,15 +359,15 @@ class Loop: |
|
|
|
if self._task and not self._task.done(): |
|
|
|
self._stop_next_iteration = True |
|
|
|
|
|
|
|
def _can_be_cancelled(self): |
|
|
|
return not self._is_being_cancelled and self._task and not self._task.done() |
|
|
|
def _can_be_cancelled(self) -> bool: |
|
|
|
return bool(not self._is_being_cancelled and self._task and not self._task.done()) |
|
|
|
|
|
|
|
def cancel(self): |
|
|
|
def cancel(self) -> None: |
|
|
|
"""Cancels the internal task, if it is running.""" |
|
|
|
if self._can_be_cancelled(): |
|
|
|
self._task.cancel() |
|
|
|
|
|
|
|
def restart(self, *args, **kwargs): |
|
|
|
def restart(self, *args: Any, **kwargs: Any) -> None: |
|
|
|
r"""A convenience method to restart the internal task. |
|
|
|
|
|
|
|
.. note:: |
|
|
@ -355,7 +391,7 @@ class Loop: |
|
|
|
self._task.add_done_callback(restart_when_over) |
|
|
|
self._task.cancel() |
|
|
|
|
|
|
|
def add_exception_type(self, *exceptions): |
|
|
|
def add_exception_type(self, *exceptions: Type[BaseException]) -> None: |
|
|
|
r"""Adds exception types to be handled during the reconnect logic. |
|
|
|
|
|
|
|
By default the exception types handled are those handled by |
|
|
@ -384,7 +420,7 @@ class Loop: |
|
|
|
|
|
|
|
self._valid_exception = (*self._valid_exception, *exceptions) |
|
|
|
|
|
|
|
def clear_exception_types(self): |
|
|
|
def clear_exception_types(self) -> None: |
|
|
|
"""Removes all exception types that are handled. |
|
|
|
|
|
|
|
.. note:: |
|
|
@ -393,7 +429,7 @@ class Loop: |
|
|
|
""" |
|
|
|
self._valid_exception = tuple() |
|
|
|
|
|
|
|
def remove_exception_type(self, *exceptions): |
|
|
|
def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool: |
|
|
|
r"""Removes exception types from being handled during the reconnect logic. |
|
|
|
|
|
|
|
Parameters |
|
|
@ -410,34 +446,34 @@ class Loop: |
|
|
|
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) |
|
|
|
return len(self._valid_exception) == old_length - len(exceptions) |
|
|
|
|
|
|
|
def get_task(self): |
|
|
|
def get_task(self) -> Optional[asyncio.Task]: |
|
|
|
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" |
|
|
|
return self._task |
|
|
|
|
|
|
|
def is_being_cancelled(self): |
|
|
|
def is_being_cancelled(self) -> bool: |
|
|
|
"""Whether the task is being cancelled.""" |
|
|
|
return self._is_being_cancelled |
|
|
|
|
|
|
|
def failed(self): |
|
|
|
def failed(self) -> bool: |
|
|
|
""":class:`bool`: Whether the internal task has failed. |
|
|
|
|
|
|
|
.. versionadded:: 1.2 |
|
|
|
""" |
|
|
|
return self._has_failed |
|
|
|
|
|
|
|
def is_running(self): |
|
|
|
def is_running(self) -> bool: |
|
|
|
""":class:`bool`: Check if the task is currently running. |
|
|
|
|
|
|
|
.. versionadded:: 1.4 |
|
|
|
""" |
|
|
|
return not bool(self._task.done()) if self._task else False |
|
|
|
|
|
|
|
async def _error(self, *args): |
|
|
|
exception = args[-1] |
|
|
|
async def _error(self, *args: Any) -> None: |
|
|
|
exception: Exception = args[-1] |
|
|
|
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr) |
|
|
|
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) |
|
|
|
|
|
|
|
def before_loop(self, coro): |
|
|
|
def before_loop(self, coro: FT) -> FT: |
|
|
|
"""A decorator that registers a coroutine to be called before the loop starts running. |
|
|
|
|
|
|
|
This is useful if you want to wait for some bot state before the loop starts, |
|
|
@ -462,7 +498,7 @@ class Loop: |
|
|
|
self._before_loop = coro |
|
|
|
return coro |
|
|
|
|
|
|
|
def after_loop(self, coro): |
|
|
|
def after_loop(self, coro: FT) -> FT: |
|
|
|
"""A decorator that register a coroutine to be called after the loop finished running. |
|
|
|
|
|
|
|
The coroutine must take no arguments (except ``self`` in a class context). |
|
|
@ -490,7 +526,7 @@ class Loop: |
|
|
|
self._after_loop = coro |
|
|
|
return coro |
|
|
|
|
|
|
|
def error(self, coro): |
|
|
|
def error(self, coro: ET) -> ET: |
|
|
|
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception. |
|
|
|
|
|
|
|
The coroutine must take only one argument the exception raised (except ``self`` in a class context). |
|
|
@ -513,11 +549,11 @@ class Loop: |
|
|
|
if not inspect.iscoroutinefunction(coro): |
|
|
|
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') |
|
|
|
|
|
|
|
self._error = coro |
|
|
|
self._error = coro # type: ignore |
|
|
|
return coro |
|
|
|
|
|
|
|
def _get_next_sleep_time(self): |
|
|
|
if self._sleep is not None: |
|
|
|
def _get_next_sleep_time(self) -> datetime.datetime: |
|
|
|
if self._sleep is not MISSING: |
|
|
|
return self._last_iteration + datetime.timedelta(seconds=self._sleep) |
|
|
|
|
|
|
|
if self._time_index >= len(self._time): |
|
|
@ -532,7 +568,7 @@ class Loop: |
|
|
|
self._time_index += 1 |
|
|
|
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) |
|
|
|
|
|
|
|
next_date = self._last_iteration |
|
|
|
next_date = cast(datetime.datetime, self._last_iteration) |
|
|
|
if self._time_index == 0: |
|
|
|
# we can assume that the earliest time should be scheduled for "tomorrow" |
|
|
|
next_date += datetime.timedelta(days=1) |
|
|
@ -540,7 +576,7 @@ class Loop: |
|
|
|
self._time_index += 1 |
|
|
|
return datetime.datetime.combine(next_date, next_time) |
|
|
|
|
|
|
|
def _prepare_time_index(self, now=None): |
|
|
|
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None: |
|
|
|
# now kwarg should be a datetime.datetime representing the time "now" |
|
|
|
# to calculate the next time index from |
|
|
|
|
|
|
@ -553,25 +589,38 @@ class Loop: |
|
|
|
else: |
|
|
|
self._time_index = 0 |
|
|
|
|
|
|
|
def _get_time_parameter(self, time, *, inst=isinstance, dt=datetime.time, utc=datetime.timezone.utc): |
|
|
|
if inst(time, dt): |
|
|
|
def _get_time_parameter( |
|
|
|
self, |
|
|
|
time: Union[datetime.time, Sequence[datetime.time]], |
|
|
|
*, |
|
|
|
dt: Type[datetime.time] = datetime.time, |
|
|
|
utc: datetime.timezone = datetime.timezone.utc, |
|
|
|
) -> List[datetime.time]: |
|
|
|
if isinstance(time, dt): |
|
|
|
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) |
|
|
|
return [ret] |
|
|
|
if not inst(time, Sequence): |
|
|
|
if not isinstance(time, Sequence): |
|
|
|
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') |
|
|
|
if not time: |
|
|
|
raise ValueError('time parameter must not be an empty sequence.') |
|
|
|
|
|
|
|
ret = [] |
|
|
|
for index, t in enumerate(time): |
|
|
|
if not inst(t, dt): |
|
|
|
if not isinstance(t, dt): |
|
|
|
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') |
|
|
|
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) |
|
|
|
|
|
|
|
ret = sorted(set(ret)) # de-dupe and sort times |
|
|
|
return ret |
|
|
|
|
|
|
|
def change_interval(self, *, seconds=0, minutes=0, hours=0, time=None): |
|
|
|
def change_interval( |
|
|
|
self, |
|
|
|
*, |
|
|
|
seconds: float = 0, |
|
|
|
minutes: float = 0, |
|
|
|
hours: float = 0, |
|
|
|
time: Union[datetime.time, Sequence[datetime.time]] = MISSING, |
|
|
|
) -> None: |
|
|
|
"""Changes the interval for the sleep time. |
|
|
|
|
|
|
|
.. versionadded:: 1.2 |
|
|
@ -604,7 +653,10 @@ class Loop: |
|
|
|
``time`` parameter was passed in conjunction with relative time parameters. |
|
|
|
""" |
|
|
|
|
|
|
|
if time is None: |
|
|
|
if time is MISSING: |
|
|
|
seconds = seconds or 0 |
|
|
|
minutes = minutes or 0 |
|
|
|
hours = hours or 0 |
|
|
|
sleep = seconds + (minutes * 60.0) + (hours * 3600.0) |
|
|
|
if sleep < 0: |
|
|
|
raise ValueError('Total number of seconds cannot be less than zero.') |
|
|
@ -613,12 +665,12 @@ class Loop: |
|
|
|
self._seconds = float(seconds) |
|
|
|
self._hours = float(hours) |
|
|
|
self._minutes = float(minutes) |
|
|
|
self._time = None |
|
|
|
self._time: List[datetime.time] = MISSING |
|
|
|
else: |
|
|
|
if any((seconds, minutes, hours)): |
|
|
|
raise TypeError('Cannot mix explicit time with relative time') |
|
|
|
self._time = self._get_time_parameter(time) |
|
|
|
self._sleep = self._seconds = self._minutes = self._hours = None |
|
|
|
self._sleep = self._seconds = self._minutes = self._hours = MISSING |
|
|
|
|
|
|
|
if self.is_running(): |
|
|
|
if self._time is not None: |
|
|
@ -631,7 +683,16 @@ class Loop: |
|
|
|
self._handle.recalculate(self._next_iteration) |
|
|
|
|
|
|
|
|
|
|
|
def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True, loop=None): |
|
|
|
def loop( |
|
|
|
*, |
|
|
|
seconds: float = MISSING, |
|
|
|
minutes: float = MISSING, |
|
|
|
hours: float = MISSING, |
|
|
|
time: Union[datetime.time, Sequence[datetime.time]] = MISSING, |
|
|
|
count: Optional[int] = None, |
|
|
|
reconnect: bool = True, |
|
|
|
loop: Optional[asyncio.AbstractEventLoop] = None, |
|
|
|
) -> Callable[[LF], Loop[LF]]: |
|
|
|
"""A decorator that schedules a task in the background for you with |
|
|
|
optional reconnect logic. The decorator returns a :class:`Loop`. |
|
|
|
|
|
|
@ -663,7 +724,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True |
|
|
|
Whether to handle errors and restart the task |
|
|
|
using an exponential back-off algorithm similar to the |
|
|
|
one used in :meth:`discord.Client.connect`. |
|
|
|
loop: :class:`asyncio.AbstractEventLoop` |
|
|
|
loop: Optional[:class:`asyncio.AbstractEventLoop`] |
|
|
|
The loop to use to register the task, if not given |
|
|
|
defaults to :func:`asyncio.get_event_loop`. |
|
|
|
|
|
|
@ -675,7 +736,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True |
|
|
|
The function was not a coroutine, an invalid value for the ``time`` parameter was passed, |
|
|
|
or ``time`` parameter was passed in conjunction with relative time parameters. |
|
|
|
""" |
|
|
|
def decorator(func): |
|
|
|
def decorator(func: LF) -> Loop[LF]: |
|
|
|
kwargs = { |
|
|
|
'seconds': seconds, |
|
|
|
'minutes': minutes, |
|
|
@ -683,7 +744,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True |
|
|
|
'count': count, |
|
|
|
'time': time, |
|
|
|
'reconnect': reconnect, |
|
|
|
'loop': loop |
|
|
|
'loop': loop, |
|
|
|
} |
|
|
|
return Loop(func, **kwargs) |
|
|
|
return decorator |
|
|
|