diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index e3254cebd..2339843a4 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -148,13 +148,17 @@ class Loop(Generic[LF]): self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop()) return self._handle.wait() + def _is_relative_time(self) -> bool: + return self._time is MISSING + + def _is_explicit_time(self) -> bool: + return self._time is not MISSING + async def _loop(self, *args: Any, **kwargs: Any) -> None: backoff = ExponentialBackoff() await self._call_loop_function('before_loop') self._last_iteration_failed = False - if self._time is not MISSING: - # the time index should be prepared every time the internal loop is started - self._prepare_time_index() + if self._is_explicit_time(): self._next_iteration = self._get_next_sleep_time() else: self._next_iteration = datetime.datetime.now(datetime.timezone.utc) @@ -164,7 +168,7 @@ class Loop(Generic[LF]): return while True: # sleep before the body of the task for explicit time intervals - if self._time is not MISSING: + if self._is_explicit_time(): await self._try_sleep_until(self._next_iteration) if not self._last_iteration_failed: self._last_iteration = self._next_iteration @@ -182,7 +186,7 @@ class Loop(Generic[LF]): return # sleep after the body of the task for relative time intervals - if self._time is MISSING: + if self._is_relative_time(): await self._try_sleep_until(self._next_iteration) self._current_loop += 1 @@ -553,47 +557,36 @@ class Loop(Generic[LF]): self._error = coro # type: ignore return coro - def _get_next_sleep_time(self) -> datetime.datetime: + def _get_next_sleep_time(self, now: datetime.datetime = MISSING) -> datetime.datetime: if self._sleep is not MISSING: return self._last_iteration + datetime.timedelta(seconds=self._sleep) - if self._time_index >= len(self._time): - self._time_index = 0 - if self._current_loop == 0: - # if we're at the last index on the first iteration, we need to sleep until tomorrow - return datetime.datetime.combine( - datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0] - ) - - next_time = self._time[self._time_index] + if now is MISSING: + now = datetime.datetime.now(datetime.timezone.utc) - if self._current_loop == 0: - self._time_index += 1 - return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) + index = self._start_time_relative_to(now) - next_date = self._last_iteration - if self._time_index == 0: - # we can assume that the earliest time should be scheduled for "tomorrow" - next_date += datetime.timedelta(days=1) + if index is None: + time = self._time[0] + tomorrow = now + datetime.timedelta(days=1) + date = tomorrow.date() + else: + date = now.date() + time = self._time[index] - self._time_index += 1 - return datetime.datetime.combine(next_date, next_time) + return datetime.datetime.combine(date, time, tzinfo=time.tzinfo or datetime.timezone.utc) - def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None: + def _start_time_relative_to(self, now: datetime.datetime) -> Optional[int]: # now kwarg should be a datetime.datetime representing the time "now" # to calculate the next time index from # pre-condition: self._time is set - time_now = ( - now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) - ).timetz() - idx = -1 + time_now = now.timetz() for idx, time in enumerate(self._time): if time >= time_now: - self._time_index = idx - break + return idx else: - self._time_index = idx + 1 + return None def _get_time_parameter( self, @@ -683,10 +676,6 @@ class Loop(Generic[LF]): self._sleep = self._seconds = self._minutes = self._hours = MISSING if self.is_running(): - if self._time is not MISSING: - # prepare the next time index starting from after the last iteration - self._prepare_time_index(now=self._last_iteration) - self._next_iteration = self._get_next_sleep_time() if self._handle and not self._handle.done(): # the loop is sleeping, recalculate based on new interval @@ -701,7 +690,6 @@ def loop( time: Union[datetime.time, Sequence[datetime.time]] = MISSING, count: Optional[int] = None, reconnect: bool = True, - loop: asyncio.AbstractEventLoop = MISSING, ) -> Callable[[LF], Loop[LF]]: """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -745,6 +733,14 @@ def loop( """ def decorator(func: LF) -> Loop[LF]: - return Loop[LF](func, seconds=seconds, minutes=minutes, hours=hours, count=count, time=time, reconnect=reconnect) + return Loop[LF]( + func, + seconds=seconds, + minutes=minutes, + hours=hours, + count=count, + time=time, + reconnect=reconnect, + ) return decorator diff --git a/tests/test_ext_tasks.py b/tests/test_ext_tasks.py index 4f4c9591c..a0337ffb3 100644 --- a/tests/test_ext_tasks.py +++ b/tests/test_ext_tasks.py @@ -75,3 +75,23 @@ async def test_explicit_initial_runs_tomorrow_multi(): assert not has_run finally: loop.cancel() + + +def test_task_regression_issue7659(): + jst = datetime.timezone(datetime.timedelta(hours=9)) + + # 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00 + times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)] + + @tasks.loop(time=times) + async def loop(): + pass + + before_midnight = datetime.datetime(2022, 3, 12, 23, 50, 59, tzinfo=jst) + after_midnight = before_midnight + datetime.timedelta(minutes=9, seconds=2) + + expected_before_midnight = datetime.datetime(2022, 3, 13, 0, 0, 0, tzinfo=jst) + expected_after_midnight = datetime.datetime(2022, 3, 13, 3, 0, 0, tzinfo=jst) + + assert loop._get_next_sleep_time(before_midnight) == expected_before_midnight + assert loop._get_next_sleep_time(after_midnight) == expected_after_midnight