From ef22178deefc403cefff60c1a6983e4cc73a0f3e Mon Sep 17 00:00:00 2001 From: Josh Date: Wed, 12 May 2021 20:31:40 +1000 Subject: [PATCH] [tasks] Type hint the tasks extension --- discord/ext/tasks/__init__.py | 191 ++++++++++++++++++++++------------ 1 file changed, 126 insertions(+), 65 deletions(-) diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 940b9f889..dced93d47 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -22,8 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio import datetime +from typing import ( + Any, + Awaitable, + Callable, + Generic, + List, + Optional, + Type, + TypeVar, + Union, + cast, +) + import aiohttp import discord import inspect @@ -33,6 +48,7 @@ import traceback from collections.abc import Sequence from discord.backoff import ExponentialBackoff +from discord.utils import MISSING log = logging.getLogger(__name__) @@ -40,41 +56,58 @@ __all__ = ( 'loop', ) +T = TypeVar('T') +_func = Callable[..., Awaitable[Any]] +LF = TypeVar('LF', bound=_func) +FT = TypeVar('FT', bound=_func) +ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) +LT = TypeVar('LT', bound='Loop') + + class SleepHandle: __slots__ = ('future', 'loop', 'handle') - def __init__(self, dt, *, loop): + def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: self.loop = loop self.future = future = loop.create_future() relative_delta = discord.utils.compute_timedelta(dt) self.handle = loop.call_later(relative_delta, future.set_result, True) - def recalculate(self, dt): + def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = discord.utils.compute_timedelta(dt) self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) - def wait(self): + def wait(self) -> asyncio.Future: return self.future - def done(self): + def done(self) -> bool: return self.future.done() - def cancel(self): + def cancel(self) -> None: self.handle.cancel() self.future.cancel() -class Loop: +class Loop(Generic[LF]): """A background task helper that abstracts the loop and reconnection logic for you. The main interface to create this is through :func:`loop`. """ - def __init__(self, coro, seconds, hours, minutes, time, count, reconnect, loop): - self.coro = coro - self.reconnect = reconnect - self.loop = loop - self.count = count + def __init__(self, + coro: LF, + seconds: float, + hours: float, + minutes: float, + time: Union[datetime.time, Sequence[datetime.time]], + count: Optional[int], + reconnect: bool, + loop: Optional[asyncio.AbstractEventLoop], + ) -> None: + self.coro: LF = coro + self.reconnect: bool = reconnect + self.loop: Optional[asyncio.AbstractEventLoop] = loop + self.count: Optional[int] = count self._current_loop = 0 self._handle = None self._task = None @@ -104,7 +137,7 @@ class Loop: if not inspect.iscoroutinefunction(self.coro): raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') - async def _call_loop_function(self, name, *args, **kwargs): + async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: coro = getattr(self, '_' + name) if coro is None: return @@ -114,16 +147,16 @@ class Loop: else: await coro(*args, **kwargs) - def _try_sleep_until(self, dt): - self._handle = SleepHandle(dt=dt, loop=self.loop) + + def _try_sleep_until(self, dt: datetime.datetime): + self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore return self._handle.wait() - async def _loop(self, *args, **kwargs): + async def _loop(self, *args: Any, **kwargs: Any) -> None: backoff = ExponentialBackoff() await self._call_loop_function('before_loop') - sleep_until = discord.utils.sleep_until self._last_iteration_failed = False - if self._time is not None: + if self._time is not MISSING: # the time index should be prepared every time the internal loop is started self._prepare_time_index() self._next_iteration = self._get_next_sleep_time() @@ -174,7 +207,7 @@ class Loop: self._stop_next_iteration = False self._has_failed = False - def __get__(self, obj, objtype): + def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]: if obj is None: return self @@ -183,8 +216,8 @@ class Loop: seconds=self._seconds, hours=self._hours, minutes=self._minutes, - count=self.count, time=self._time, + count=self.count, reconnect=self.reconnect, loop=self.loop, ) @@ -196,49 +229,52 @@ class Loop: return copy @property - def seconds(self): + def seconds(self) -> Optional[float]: """Optional[:class:`float`]: Read-only value for the number of seconds between each iteration. ``None`` if an explicit ``time`` value was passed instead. .. versionadded:: 2.0 """ - return self._seconds + if self._seconds is not MISSING: + return self._seconds @property - def minutes(self): + def minutes(self) -> Optional[float]: """Optional[:class:`float`]: Read-only value for the number of minutes between each iteration. ``None`` if an explicit ``time`` value was passed instead. .. versionadded:: 2.0 """ - return self._minutes + if self._minutes is not MISSING: + return self._minutes @property - def hours(self): + def hours(self) -> Optional[float]: """Optional[:class:`float`]: Read-only value for the number of hours between each iteration. ``None`` if an explicit ``time`` value was passed instead. .. versionadded:: 2.0 """ - return self._hours + if self._hours is not MISSING: + return self._hours @property - def time(self): + def time(self) -> Optional[List[datetime.time]]: """Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at. ``None`` if relative times were passed instead. .. versionadded:: 2.0 """ - if self._time is not None: + if self._time is not MISSING: return self._time.copy() @property - def current_loop(self): + def current_loop(self) -> int: """:class:`int`: The current iteration of the loop.""" return self._current_loop @property - def next_iteration(self): + def next_iteration(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur. .. versionadded:: 1.3 @@ -249,7 +285,7 @@ class Loop: return None return self._next_iteration - async def __call__(self, *args, **kwargs): + async def __call__(self, *args: Any, **kwargs: Any) -> Any: r"""|coro| Calls the internal callback that the task holds. @@ -269,7 +305,7 @@ class Loop: return await self.coro(*args, **kwargs) - def start(self, *args, **kwargs): + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: r"""Starts the internal task in the event loop. Parameters @@ -302,7 +338,7 @@ class Loop: self._task = self.loop.create_task(self._loop(*args, **kwargs)) return self._task - def stop(self): + def stop(self) -> None: r"""Gracefully stops the task from running. Unlike :meth:`cancel`\, this allows the task to finish its @@ -323,15 +359,15 @@ class Loop: if self._task and not self._task.done(): self._stop_next_iteration = True - def _can_be_cancelled(self): - return not self._is_being_cancelled and self._task and not self._task.done() + def _can_be_cancelled(self) -> bool: + return bool(not self._is_being_cancelled and self._task and not self._task.done()) - def cancel(self): + def cancel(self) -> None: """Cancels the internal task, if it is running.""" if self._can_be_cancelled(): self._task.cancel() - def restart(self, *args, **kwargs): + def restart(self, *args: Any, **kwargs: Any) -> None: r"""A convenience method to restart the internal task. .. note:: @@ -355,7 +391,7 @@ class Loop: self._task.add_done_callback(restart_when_over) self._task.cancel() - def add_exception_type(self, *exceptions): + def add_exception_type(self, *exceptions: Type[BaseException]) -> None: r"""Adds exception types to be handled during the reconnect logic. By default the exception types handled are those handled by @@ -384,7 +420,7 @@ class Loop: self._valid_exception = (*self._valid_exception, *exceptions) - def clear_exception_types(self): + def clear_exception_types(self) -> None: """Removes all exception types that are handled. .. note:: @@ -393,7 +429,7 @@ class Loop: """ self._valid_exception = tuple() - def remove_exception_type(self, *exceptions): + def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool: r"""Removes exception types from being handled during the reconnect logic. Parameters @@ -410,34 +446,34 @@ class Loop: self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) return len(self._valid_exception) == old_length - len(exceptions) - def get_task(self): + def get_task(self) -> Optional[asyncio.Task]: """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" return self._task - def is_being_cancelled(self): + def is_being_cancelled(self) -> bool: """Whether the task is being cancelled.""" return self._is_being_cancelled - def failed(self): + def failed(self) -> bool: """:class:`bool`: Whether the internal task has failed. .. versionadded:: 1.2 """ return self._has_failed - def is_running(self): + def is_running(self) -> bool: """:class:`bool`: Check if the task is currently running. .. versionadded:: 1.4 """ return not bool(self._task.done()) if self._task else False - async def _error(self, *args): - exception = args[-1] + async def _error(self, *args: Any) -> None: + exception: Exception = args[-1] print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) - def before_loop(self, coro): + def before_loop(self, coro: FT) -> FT: """A decorator that registers a coroutine to be called before the loop starts running. This is useful if you want to wait for some bot state before the loop starts, @@ -462,7 +498,7 @@ class Loop: self._before_loop = coro return coro - def after_loop(self, coro): + def after_loop(self, coro: FT) -> FT: """A decorator that register a coroutine to be called after the loop finished running. The coroutine must take no arguments (except ``self`` in a class context). @@ -490,7 +526,7 @@ class Loop: self._after_loop = coro return coro - def error(self, coro): + def error(self, coro: ET) -> ET: """A decorator that registers a coroutine to be called if the task encounters an unhandled exception. The coroutine must take only one argument the exception raised (except ``self`` in a class context). @@ -513,11 +549,11 @@ class Loop: if not inspect.iscoroutinefunction(coro): raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') - self._error = coro + self._error = coro # type: ignore return coro - def _get_next_sleep_time(self): - if self._sleep is not None: + def _get_next_sleep_time(self) -> datetime.datetime: + if self._sleep is not MISSING: return self._last_iteration + datetime.timedelta(seconds=self._sleep) if self._time_index >= len(self._time): @@ -532,7 +568,7 @@ class Loop: self._time_index += 1 return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) - next_date = self._last_iteration + next_date = cast(datetime.datetime, self._last_iteration) if self._time_index == 0: # we can assume that the earliest time should be scheduled for "tomorrow" next_date += datetime.timedelta(days=1) @@ -540,7 +576,7 @@ class Loop: self._time_index += 1 return datetime.datetime.combine(next_date, next_time) - def _prepare_time_index(self, now=None): + def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None: # now kwarg should be a datetime.datetime representing the time "now" # to calculate the next time index from @@ -553,25 +589,38 @@ class Loop: else: self._time_index = 0 - def _get_time_parameter(self, time, *, inst=isinstance, dt=datetime.time, utc=datetime.timezone.utc): - if inst(time, dt): + def _get_time_parameter( + self, + time: Union[datetime.time, Sequence[datetime.time]], + *, + dt: Type[datetime.time] = datetime.time, + utc: datetime.timezone = datetime.timezone.utc, + ) -> List[datetime.time]: + if isinstance(time, dt): ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) return [ret] - if not inst(time, Sequence): + if not isinstance(time, Sequence): raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') if not time: raise ValueError('time parameter must not be an empty sequence.') ret = [] for index, t in enumerate(time): - if not inst(t, dt): + if not isinstance(t, dt): raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret = sorted(set(ret)) # de-dupe and sort times return ret - def change_interval(self, *, seconds=0, minutes=0, hours=0, time=None): + def change_interval( + self, + *, + seconds: float = 0, + minutes: float = 0, + hours: float = 0, + time: Union[datetime.time, Sequence[datetime.time]] = MISSING, + ) -> None: """Changes the interval for the sleep time. .. versionadded:: 1.2 @@ -604,7 +653,10 @@ class Loop: ``time`` parameter was passed in conjunction with relative time parameters. """ - if time is None: + if time is MISSING: + seconds = seconds or 0 + minutes = minutes or 0 + hours = hours or 0 sleep = seconds + (minutes * 60.0) + (hours * 3600.0) if sleep < 0: raise ValueError('Total number of seconds cannot be less than zero.') @@ -613,12 +665,12 @@ class Loop: self._seconds = float(seconds) self._hours = float(hours) self._minutes = float(minutes) - self._time = None + self._time: List[datetime.time] = MISSING else: if any((seconds, minutes, hours)): raise TypeError('Cannot mix explicit time with relative time') self._time = self._get_time_parameter(time) - self._sleep = self._seconds = self._minutes = self._hours = None + self._sleep = self._seconds = self._minutes = self._hours = MISSING if self.is_running(): if self._time is not None: @@ -631,7 +683,16 @@ class Loop: self._handle.recalculate(self._next_iteration) -def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True, loop=None): +def loop( + *, + seconds: float = MISSING, + minutes: float = MISSING, + hours: float = MISSING, + time: Union[datetime.time, Sequence[datetime.time]] = MISSING, + count: Optional[int] = None, + reconnect: bool = True, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> Callable[[LF], Loop[LF]]: """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -663,7 +724,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True Whether to handle errors and restart the task using an exponential back-off algorithm similar to the one used in :meth:`discord.Client.connect`. - loop: :class:`asyncio.AbstractEventLoop` + loop: Optional[:class:`asyncio.AbstractEventLoop`] The loop to use to register the task, if not given defaults to :func:`asyncio.get_event_loop`. @@ -675,7 +736,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True The function was not a coroutine, an invalid value for the ``time`` parameter was passed, or ``time`` parameter was passed in conjunction with relative time parameters. """ - def decorator(func): + def decorator(func: LF) -> Loop[LF]: kwargs = { 'seconds': seconds, 'minutes': minutes, @@ -683,7 +744,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True 'count': count, 'time': time, 'reconnect': reconnect, - 'loop': loop + 'loop': loop, } return Loop(func, **kwargs) return decorator