diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index d2ae2c750..81e8dc79d 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -31,6 +31,7 @@ import logging import sys import traceback +from collections.abc import Sequence from discord.backoff import ExponentialBackoff log = logging.getLogger(__name__) @@ -39,17 +40,43 @@ __all__ = ( 'loop', ) +class SleepHandle: + __slots__ = ('future', 'loop', 'handle') + + def __init__(self, dt, *, loop): + self.loop = loop + self.future = future = loop.create_future() + relative_delta = discord.utils.compute_timedelta(dt) + self.handle = loop.call_later(relative_delta, future.set_result, True) + + def recalculate(self, dt): + self.handle.cancel() + relative_delta = discord.utils.compute_timedelta(dt) + self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) + + def wait(self): + return self.future + + def done(self): + return self.future.done() + + def cancel(self): + self.handle.cancel() + self.future.cancel() + + class Loop: """A background task helper that abstracts the loop and reconnection logic for you. The main interface to create this is through :func:`loop`. """ - def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop): + def __init__(self, coro, seconds, hours, minutes, time, count, reconnect, loop): self.coro = coro self.reconnect = reconnect self.loop = loop self.count = count self._current_loop = 0 + self._handle = None self._task = None self._injected = None self._valid_exception = ( @@ -69,7 +96,7 @@ class Loop: if self.count is not None and self.count <= 0: raise ValueError('count must be greater than 0 or None.') - self.change_interval(seconds=seconds, minutes=minutes, hours=hours) + self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self._last_iteration_failed = False self._last_iteration = None self._next_iteration = None @@ -87,14 +114,23 @@ class Loop: else: await coro(*args, **kwargs) + def _try_sleep_until(self, dt): + self._handle = SleepHandle(dt=dt, loop=self.loop) + return self._handle.wait() + async def _loop(self, *args, **kwargs): backoff = ExponentialBackoff() await self._call_loop_function('before_loop') sleep_until = discord.utils.sleep_until self._last_iteration_failed = False - self._next_iteration = datetime.datetime.now(datetime.timezone.utc) + if self._time is not None: + # the time index should be prepared every time the internal loop is started + self._prepare_time_index() + self._next_iteration = self._get_next_sleep_time() + else: + self._next_iteration = datetime.datetime.now(datetime.timezone.utc) try: - await asyncio.sleep(0) # allows canceling in before_loop + await self._try_sleep_until(self._next_iteration) while True: if not self._last_iteration_failed: self._last_iteration = self._next_iteration @@ -102,22 +138,26 @@ class Loop: try: await self.coro(*args, **kwargs) self._last_iteration_failed = False - now = datetime.datetime.now(datetime.timezone.utc) - if now > self._next_iteration: - self._next_iteration = now except self._valid_exception: self._last_iteration_failed = True if not self.reconnect: raise await asyncio.sleep(backoff.delay()) else: - await sleep_until(self._next_iteration) + await self._try_sleep_until(self._next_iteration) if self._stop_next_iteration: return + + now = datetime.datetime.now(datetime.timezone.utc) + if now > self._next_iteration: + self._prepare_time_index(now) + self._next_iteration = now + self._current_loop += 1 if self._current_loop == self.count: break + except asyncio.CancelledError: self._is_being_cancelled = True raise @@ -127,6 +167,7 @@ class Loop: raise exc finally: await self._call_loop_function('after_loop') + self._handle.cancel() self._is_being_cancelled = False self._current_loop = 0 self._stop_next_iteration = False @@ -136,8 +177,16 @@ class Loop: if obj is None: return self - copy = Loop(self.coro, seconds=self.seconds, hours=self.hours, minutes=self.minutes, - count=self.count, reconnect=self.reconnect, loop=self.loop) + copy = Loop( + self.coro, + seconds=self._seconds, + hours=self._hours, + minutes=self._minutes, + count=self.count, + time=self._time, + reconnect=self.reconnect, + loop=self.loop, + ) copy._injected = obj copy._before_loop = self._before_loop copy._after_loop = self._after_loop @@ -145,6 +194,43 @@ class Loop: setattr(obj, self.coro.__name__, copy) return copy + @property + def seconds(self): + """Optional[:class:`float`]: Read-only value for the number of seconds + between each iteration. ``None`` if an explicit ``time`` value was passed instead. + + .. versionadded:: 2.0 + """ + return self._seconds + + @property + def minutes(self): + """Optional[:class:`float`]: Read-only value for the number of minutes + between each iteration. ``None`` if an explicit ``time`` value was passed instead. + + .. versionadded:: 2.0 + """ + return self._minutes + + @property + def hours(self): + """Optional[:class:`float`]: Read-only value for the number of hours + between each iteration. ``None`` if an explicit ``time`` value was passed instead. + + .. versionadded:: 2.0 + """ + return self._hours + + @property + def time(self): + """Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at. + ``None`` if relative times were passed instead. + + .. versionadded:: 2.0 + """ + if self._time is not None: + return self._time.copy() + @property def current_loop(self): """:class:`int`: The current iteration of the loop.""" @@ -430,16 +516,63 @@ class Loop: return coro def _get_next_sleep_time(self): - return self._last_iteration + datetime.timedelta(seconds=self._sleep) - - def change_interval(self, *, seconds=0, minutes=0, hours=0): + if self._sleep is not None: + return self._last_iteration + datetime.timedelta(seconds=self._sleep) + + if self._time_index >= len(self._time): + self._time_index = 0 + if self._current_loop == 0: + # if we're at the last index on the first iteration, we need to sleep until tomorrow + return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]) + + next_time = self._time[self._time_index] + + if self._current_loop == 0: + self._time_index += 1 + return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) + + next_date = self._last_iteration + if self._time_index == 0: + # we can assume that the earliest time should be scheduled for "tomorrow" + next_date += datetime.timedelta(days=1) + + self._time_index += 1 + return datetime.datetime.combine(next_date, next_time) + + def _prepare_time_index(self, now=None): + # now kwarg should be a datetime.datetime representing the time "now" + # to calculate the next time index from + + # pre-condition: self._time is set + time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz() + for idx, time in enumerate(self._time): + if time >= time_now: + self._time_index = idx + break + else: + self._time_index = 0 + + def _get_time_parameter(self, time, *, inst=isinstance, dt=datetime.time, utc=datetime.timezone.utc): + if inst(time, dt): + ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) + return [ret] + if not inst(time, Sequence): + raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') + if not time: + raise ValueError('time parameter must not be an empty sequence.') + + ret = [] + for index, t in enumerate(time): + if not inst(t, dt): + raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') + ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) + + ret = sorted(set(ret)) # de-dupe and sort times + return ret + + def change_interval(self, *, seconds=0, minutes=0, hours=0, time=None): """Changes the interval for the sleep time. - .. note:: - - This only applies on the next loop iteration. If it is desirable for the change of interval - to be applied right away, cancel the task with :meth:`cancel`. - .. versionadded:: 1.2 Parameters @@ -450,23 +583,54 @@ class Loop: The number of minutes between every iteration. hours: :class:`float` The number of hours between every iteration. + time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]] + The exact times to run this loop at. Either a non-empty list or a single + value of :class:`datetime.time` should be passed. + This cannot be used in conjunction with the relative time parameters. + + .. versionadded:: 2.0 + + .. note:: + + Duplicate times will be ignored, and only run once. Raises ------- ValueError An invalid value was given. + TypeError + An invalid value for the ``time`` parameter was passed, or the + ``time`` parameter was passed in conjunction with relative time parameters. """ - sleep = seconds + (minutes * 60.0) + (hours * 3600.0) - if sleep < 0: - raise ValueError('Total number of seconds cannot be less than zero.') + if time is None: + sleep = seconds + (minutes * 60.0) + (hours * 3600.0) + if sleep < 0: + raise ValueError('Total number of seconds cannot be less than zero.') + + self._sleep = sleep + self._seconds = float(seconds) + self._hours = float(hours) + self._minutes = float(minutes) + self._time = None + else: + if any((seconds, minutes, hours)): + raise TypeError('Cannot mix explicit time with relative time') + self._time = self._get_time_parameter(time) + self._sleep = self._seconds = self._minutes = self._hours = None + + if self.is_running(): + if self._time is not None: + # prepare the next time index starting from after the last iteration + self._prepare_time_index(now=self._last_iteration) + + self._next_iteration = self._get_next_sleep_time() + if not self._handle.done(): + # the loop is sleeping, recalculate based on new interval + self._handle.recalculate(self._next_iteration) - self._sleep = sleep - self.seconds = seconds - self.hours = hours - self.minutes = minutes -def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None): +def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True, loop=None): """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -478,6 +642,19 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None The number of minutes between every iteration. hours: :class:`float` The number of hours between every iteration. + time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]] + The exact times to run this loop at. Either a non-empty list or a single + value of :class:`datetime.time` should be passed. Timezones are supported. + If no timezone is given for the times, it is assumed to represent UTC time. + + This cannot be used in conjunction with the relative time parameters. + + .. note:: + + Duplicate times will be ignored, and only run once. + + .. versionadded:: 2.0 + count: Optional[:class:`int`] The number of loops to do, ``None`` if it should be an infinite loop. @@ -494,7 +671,8 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None ValueError An invalid value was given. TypeError - The function was not a coroutine. + The function was not a coroutine, an invalid value for the ``time`` parameter was passed, + or ``time`` parameter was passed in conjunction with relative time parameters. """ def decorator(func): kwargs = { @@ -502,6 +680,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None 'minutes': minutes, 'hours': hours, 'count': count, + 'time': time, 'reconnect': reconnect, 'loop': loop } diff --git a/discord/utils.py b/discord/utils.py index 32d8feb4a..22a0c40c9 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -503,6 +503,13 @@ async def sane_wait_for(futures, *, timeout): return done +def compute_timedelta(dt: datetime.datetime): + if dt.tzinfo is None: + dt = dt.astimezone() + now = datetime.datetime.now(datetime.timezone.utc) + return max((dt - now).total_seconds(), 0) + + async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]: """|coro| @@ -520,11 +527,8 @@ async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Op result: Any If provided is returned to the caller when the coroutine completes. """ - if when.tzinfo is None: - when = when.astimezone() - now = datetime.datetime.now(datetime.timezone.utc) - delta = (when - now).total_seconds() - return await asyncio.sleep(max(delta, 0), result) + delta = compute_timedelta(when) + return await asyncio.sleep(delta, result) def utcnow() -> datetime.datetime: