Browse Source

[tasks] Refactor tasks to not store a time index state

It's better to recompute it every time rather than suffer from
maintaining the extra state.
pull/7674/head
Rapptz 3 years ago
parent
commit
6a43d60acf
  1. 72
      discord/ext/tasks/__init__.py
  2. 20
      tests/test_ext_tasks.py

72
discord/ext/tasks/__init__.py

@ -148,13 +148,17 @@ class Loop(Generic[LF]):
self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
return self._handle.wait()
def _is_relative_time(self) -> bool:
return self._time is MISSING
def _is_explicit_time(self) -> bool:
return self._time is not MISSING
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
self._prepare_time_index()
if self._is_explicit_time():
self._next_iteration = self._get_next_sleep_time()
else:
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
@ -164,7 +168,7 @@ class Loop(Generic[LF]):
return
while True:
# sleep before the body of the task for explicit time intervals
if self._time is not MISSING:
if self._is_explicit_time():
await self._try_sleep_until(self._next_iteration)
if not self._last_iteration_failed:
self._last_iteration = self._next_iteration
@ -182,7 +186,7 @@ class Loop(Generic[LF]):
return
# sleep after the body of the task for relative time intervals
if self._time is MISSING:
if self._is_relative_time():
await self._try_sleep_until(self._next_iteration)
self._current_loop += 1
@ -553,47 +557,36 @@ class Loop(Generic[LF]):
self._error = coro # type: ignore
return coro
def _get_next_sleep_time(self) -> datetime.datetime:
def _get_next_sleep_time(self, now: datetime.datetime = MISSING) -> datetime.datetime:
if self._sleep is not MISSING:
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
if self._time_index >= len(self._time):
self._time_index = 0
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
next_time = self._time[self._time_index]
if now is MISSING:
now = datetime.datetime.now(datetime.timezone.utc)
if self._current_loop == 0:
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
index = self._start_time_relative_to(now)
next_date = self._last_iteration
if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1)
if index is None:
time = self._time[0]
tomorrow = now + datetime.timedelta(days=1)
date = tomorrow.date()
else:
date = now.date()
time = self._time[index]
self._time_index += 1
return datetime.datetime.combine(next_date, next_time)
return datetime.datetime.combine(date, time, tzinfo=time.tzinfo or datetime.timezone.utc)
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
def _start_time_relative_to(self, now: datetime.datetime) -> Optional[int]:
# now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from
# pre-condition: self._time is set
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
idx = -1
time_now = now.timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
self._time_index = idx
break
return idx
else:
self._time_index = idx + 1
return None
def _get_time_parameter(
self,
@ -683,10 +676,6 @@ class Loop(Generic[LF]):
self._sleep = self._seconds = self._minutes = self._hours = MISSING
if self.is_running():
if self._time is not MISSING:
# prepare the next time index starting from after the last iteration
self._prepare_time_index(now=self._last_iteration)
self._next_iteration = self._get_next_sleep_time()
if self._handle and not self._handle.done():
# the loop is sleeping, recalculate based on new interval
@ -701,7 +690,6 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
@ -745,6 +733,14 @@ def loop(
"""
def decorator(func: LF) -> Loop[LF]:
return Loop[LF](func, seconds=seconds, minutes=minutes, hours=hours, count=count, time=time, reconnect=reconnect)
return Loop[LF](
func,
seconds=seconds,
minutes=minutes,
hours=hours,
count=count,
time=time,
reconnect=reconnect,
)
return decorator

20
tests/test_ext_tasks.py

@ -75,3 +75,23 @@ async def test_explicit_initial_runs_tomorrow_multi():
assert not has_run
finally:
loop.cancel()
def test_task_regression_issue7659():
jst = datetime.timezone(datetime.timedelta(hours=9))
# 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00
times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)]
@tasks.loop(time=times)
async def loop():
pass
before_midnight = datetime.datetime(2022, 3, 12, 23, 50, 59, tzinfo=jst)
after_midnight = before_midnight + datetime.timedelta(minutes=9, seconds=2)
expected_before_midnight = datetime.datetime(2022, 3, 13, 0, 0, 0, tzinfo=jst)
expected_after_midnight = datetime.datetime(2022, 3, 13, 3, 0, 0, tzinfo=jst)
assert loop._get_next_sleep_time(before_midnight) == expected_before_midnight
assert loop._get_next_sleep_time(after_midnight) == expected_after_midnight

Loading…
Cancel
Save