Browse Source

[tasks] Type hint the tasks extension

pull/6900/head
Josh 4 years ago
committed by GitHub
parent
commit
ef22178dee
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 191
      discord/ext/tasks/__init__.py

191
discord/ext/tasks/__init__.py

@ -22,8 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
from typing import (
Any,
Awaitable,
Callable,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
import aiohttp import aiohttp
import discord import discord
import inspect import inspect
@ -33,6 +48,7 @@ import traceback
from collections.abc import Sequence from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -40,41 +56,58 @@ __all__ = (
'loop', 'loop',
) )
T = TypeVar('T')
_func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop')
class SleepHandle: class SleepHandle:
__slots__ = ('future', 'loop', 'handle') __slots__ = ('future', 'loop', 'handle')
def __init__(self, dt, *, loop): def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop self.loop = loop
self.future = future = loop.create_future() self.future = future = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, future.set_result, True) self.handle = loop.call_later(relative_delta, future.set_result, True)
def recalculate(self, dt): def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel() self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self): def wait(self) -> asyncio.Future:
return self.future return self.future
def done(self): def done(self) -> bool:
return self.future.done() return self.future.done()
def cancel(self): def cancel(self) -> None:
self.handle.cancel() self.handle.cancel()
self.future.cancel() self.future.cancel()
class Loop: class Loop(Generic[LF]):
"""A background task helper that abstracts the loop and reconnection logic for you. """A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`. The main interface to create this is through :func:`loop`.
""" """
def __init__(self, coro, seconds, hours, minutes, time, count, reconnect, loop): def __init__(self,
self.coro = coro coro: LF,
self.reconnect = reconnect seconds: float,
self.loop = loop hours: float,
self.count = count minutes: float,
time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int],
reconnect: bool,
loop: Optional[asyncio.AbstractEventLoop],
) -> None:
self.coro: LF = coro
self.reconnect: bool = reconnect
self.loop: Optional[asyncio.AbstractEventLoop] = loop
self.count: Optional[int] = count
self._current_loop = 0 self._current_loop = 0
self._handle = None self._handle = None
self._task = None self._task = None
@ -104,7 +137,7 @@ class Loop:
if not inspect.iscoroutinefunction(self.coro): if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
async def _call_loop_function(self, name, *args, **kwargs): async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, '_' + name) coro = getattr(self, '_' + name)
if coro is None: if coro is None:
return return
@ -114,16 +147,16 @@ class Loop:
else: else:
await coro(*args, **kwargs) await coro(*args, **kwargs)
def _try_sleep_until(self, dt):
self._handle = SleepHandle(dt=dt, loop=self.loop) def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
return self._handle.wait() return self._handle.wait()
async def _loop(self, *args, **kwargs): async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff() backoff = ExponentialBackoff()
await self._call_loop_function('before_loop') await self._call_loop_function('before_loop')
sleep_until = discord.utils.sleep_until
self._last_iteration_failed = False self._last_iteration_failed = False
if self._time is not None: if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started # the time index should be prepared every time the internal loop is started
self._prepare_time_index() self._prepare_time_index()
self._next_iteration = self._get_next_sleep_time() self._next_iteration = self._get_next_sleep_time()
@ -174,7 +207,7 @@ class Loop:
self._stop_next_iteration = False self._stop_next_iteration = False
self._has_failed = False self._has_failed = False
def __get__(self, obj, objtype): def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
if obj is None: if obj is None:
return self return self
@ -183,8 +216,8 @@ class Loop:
seconds=self._seconds, seconds=self._seconds,
hours=self._hours, hours=self._hours,
minutes=self._minutes, minutes=self._minutes,
count=self.count,
time=self._time, time=self._time,
count=self.count,
reconnect=self.reconnect, reconnect=self.reconnect,
loop=self.loop, loop=self.loop,
) )
@ -196,49 +229,52 @@ class Loop:
return copy return copy
@property @property
def seconds(self): def seconds(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of seconds """Optional[:class:`float`]: Read-only value for the number of seconds
between each iteration. ``None`` if an explicit ``time`` value was passed instead. between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return self._seconds if self._seconds is not MISSING:
return self._seconds
@property @property
def minutes(self): def minutes(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of minutes """Optional[:class:`float`]: Read-only value for the number of minutes
between each iteration. ``None`` if an explicit ``time`` value was passed instead. between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return self._minutes if self._minutes is not MISSING:
return self._minutes
@property @property
def hours(self): def hours(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of hours """Optional[:class:`float`]: Read-only value for the number of hours
between each iteration. ``None`` if an explicit ``time`` value was passed instead. between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return self._hours if self._hours is not MISSING:
return self._hours
@property @property
def time(self): def time(self) -> Optional[List[datetime.time]]:
"""Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at. """Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at.
``None`` if relative times were passed instead. ``None`` if relative times were passed instead.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
if self._time is not None: if self._time is not MISSING:
return self._time.copy() return self._time.copy()
@property @property
def current_loop(self): def current_loop(self) -> int:
""":class:`int`: The current iteration of the loop.""" """:class:`int`: The current iteration of the loop."""
return self._current_loop return self._current_loop
@property @property
def next_iteration(self): def next_iteration(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur. """Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur.
.. versionadded:: 1.3 .. versionadded:: 1.3
@ -249,7 +285,7 @@ class Loop:
return None return None
return self._next_iteration return self._next_iteration
async def __call__(self, *args, **kwargs): async def __call__(self, *args: Any, **kwargs: Any) -> Any:
r"""|coro| r"""|coro|
Calls the internal callback that the task holds. Calls the internal callback that the task holds.
@ -269,7 +305,7 @@ class Loop:
return await self.coro(*args, **kwargs) return await self.coro(*args, **kwargs)
def start(self, *args, **kwargs): def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
r"""Starts the internal task in the event loop. r"""Starts the internal task in the event loop.
Parameters Parameters
@ -302,7 +338,7 @@ class Loop:
self._task = self.loop.create_task(self._loop(*args, **kwargs)) self._task = self.loop.create_task(self._loop(*args, **kwargs))
return self._task return self._task
def stop(self): def stop(self) -> None:
r"""Gracefully stops the task from running. r"""Gracefully stops the task from running.
Unlike :meth:`cancel`\, this allows the task to finish its Unlike :meth:`cancel`\, this allows the task to finish its
@ -323,15 +359,15 @@ class Loop:
if self._task and not self._task.done(): if self._task and not self._task.done():
self._stop_next_iteration = True self._stop_next_iteration = True
def _can_be_cancelled(self): def _can_be_cancelled(self) -> bool:
return not self._is_being_cancelled and self._task and not self._task.done() return bool(not self._is_being_cancelled and self._task and not self._task.done())
def cancel(self): def cancel(self) -> None:
"""Cancels the internal task, if it is running.""" """Cancels the internal task, if it is running."""
if self._can_be_cancelled(): if self._can_be_cancelled():
self._task.cancel() self._task.cancel()
def restart(self, *args, **kwargs): def restart(self, *args: Any, **kwargs: Any) -> None:
r"""A convenience method to restart the internal task. r"""A convenience method to restart the internal task.
.. note:: .. note::
@ -355,7 +391,7 @@ class Loop:
self._task.add_done_callback(restart_when_over) self._task.add_done_callback(restart_when_over)
self._task.cancel() self._task.cancel()
def add_exception_type(self, *exceptions): def add_exception_type(self, *exceptions: Type[BaseException]) -> None:
r"""Adds exception types to be handled during the reconnect logic. r"""Adds exception types to be handled during the reconnect logic.
By default the exception types handled are those handled by By default the exception types handled are those handled by
@ -384,7 +420,7 @@ class Loop:
self._valid_exception = (*self._valid_exception, *exceptions) self._valid_exception = (*self._valid_exception, *exceptions)
def clear_exception_types(self): def clear_exception_types(self) -> None:
"""Removes all exception types that are handled. """Removes all exception types that are handled.
.. note:: .. note::
@ -393,7 +429,7 @@ class Loop:
""" """
self._valid_exception = tuple() self._valid_exception = tuple()
def remove_exception_type(self, *exceptions): def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool:
r"""Removes exception types from being handled during the reconnect logic. r"""Removes exception types from being handled during the reconnect logic.
Parameters Parameters
@ -410,34 +446,34 @@ class Loop:
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions) return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self): def get_task(self) -> Optional[asyncio.Task]:
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task return self._task
def is_being_cancelled(self): def is_being_cancelled(self) -> bool:
"""Whether the task is being cancelled.""" """Whether the task is being cancelled."""
return self._is_being_cancelled return self._is_being_cancelled
def failed(self): def failed(self) -> bool:
""":class:`bool`: Whether the internal task has failed. """:class:`bool`: Whether the internal task has failed.
.. versionadded:: 1.2 .. versionadded:: 1.2
""" """
return self._has_failed return self._has_failed
def is_running(self): def is_running(self) -> bool:
""":class:`bool`: Check if the task is currently running. """:class:`bool`: Check if the task is currently running.
.. versionadded:: 1.4 .. versionadded:: 1.4
""" """
return not bool(self._task.done()) if self._task else False return not bool(self._task.done()) if self._task else False
async def _error(self, *args): async def _error(self, *args: Any) -> None:
exception = args[-1] exception: Exception = args[-1]
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr) print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
def before_loop(self, coro): def before_loop(self, coro: FT) -> FT:
"""A decorator that registers a coroutine to be called before the loop starts running. """A decorator that registers a coroutine to be called before the loop starts running.
This is useful if you want to wait for some bot state before the loop starts, This is useful if you want to wait for some bot state before the loop starts,
@ -462,7 +498,7 @@ class Loop:
self._before_loop = coro self._before_loop = coro
return coro return coro
def after_loop(self, coro): def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running. """A decorator that register a coroutine to be called after the loop finished running.
The coroutine must take no arguments (except ``self`` in a class context). The coroutine must take no arguments (except ``self`` in a class context).
@ -490,7 +526,7 @@ class Loop:
self._after_loop = coro self._after_loop = coro
return coro return coro
def error(self, coro): def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception. """A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
The coroutine must take only one argument the exception raised (except ``self`` in a class context). The coroutine must take only one argument the exception raised (except ``self`` in a class context).
@ -513,11 +549,11 @@ class Loop:
if not inspect.iscoroutinefunction(coro): if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro self._error = coro # type: ignore
return coro return coro
def _get_next_sleep_time(self): def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not None: if self._sleep is not MISSING:
return self._last_iteration + datetime.timedelta(seconds=self._sleep) return self._last_iteration + datetime.timedelta(seconds=self._sleep)
if self._time_index >= len(self._time): if self._time_index >= len(self._time):
@ -532,7 +568,7 @@ class Loop:
self._time_index += 1 self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
next_date = self._last_iteration next_date = cast(datetime.datetime, self._last_iteration)
if self._time_index == 0: if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow" # we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1) next_date += datetime.timedelta(days=1)
@ -540,7 +576,7 @@ class Loop:
self._time_index += 1 self._time_index += 1
return datetime.datetime.combine(next_date, next_time) return datetime.datetime.combine(next_date, next_time)
def _prepare_time_index(self, now=None): def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
# now kwarg should be a datetime.datetime representing the time "now" # now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from # to calculate the next time index from
@ -553,25 +589,38 @@ class Loop:
else: else:
self._time_index = 0 self._time_index = 0
def _get_time_parameter(self, time, *, inst=isinstance, dt=datetime.time, utc=datetime.timezone.utc): def _get_time_parameter(
if inst(time, dt): self,
time: Union[datetime.time, Sequence[datetime.time]],
*,
dt: Type[datetime.time] = datetime.time,
utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]:
if isinstance(time, dt):
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [ret] return [ret]
if not inst(time, Sequence): if not isinstance(time, Sequence):
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
if not time: if not time:
raise ValueError('time parameter must not be an empty sequence.') raise ValueError('time parameter must not be an empty sequence.')
ret = [] ret = []
for index, t in enumerate(time): for index, t in enumerate(time):
if not inst(t, dt): if not isinstance(t, dt):
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times ret = sorted(set(ret)) # de-dupe and sort times
return ret return ret
def change_interval(self, *, seconds=0, minutes=0, hours=0, time=None): def change_interval(
self,
*,
seconds: float = 0,
minutes: float = 0,
hours: float = 0,
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
) -> None:
"""Changes the interval for the sleep time. """Changes the interval for the sleep time.
.. versionadded:: 1.2 .. versionadded:: 1.2
@ -604,7 +653,10 @@ class Loop:
``time`` parameter was passed in conjunction with relative time parameters. ``time`` parameter was passed in conjunction with relative time parameters.
""" """
if time is None: if time is MISSING:
seconds = seconds or 0
minutes = minutes or 0
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0) sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0: if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.') raise ValueError('Total number of seconds cannot be less than zero.')
@ -613,12 +665,12 @@ class Loop:
self._seconds = float(seconds) self._seconds = float(seconds)
self._hours = float(hours) self._hours = float(hours)
self._minutes = float(minutes) self._minutes = float(minutes)
self._time = None self._time: List[datetime.time] = MISSING
else: else:
if any((seconds, minutes, hours)): if any((seconds, minutes, hours)):
raise TypeError('Cannot mix explicit time with relative time') raise TypeError('Cannot mix explicit time with relative time')
self._time = self._get_time_parameter(time) self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = None self._sleep = self._seconds = self._minutes = self._hours = MISSING
if self.is_running(): if self.is_running():
if self._time is not None: if self._time is not None:
@ -631,7 +683,16 @@ class Loop:
self._handle.recalculate(self._next_iteration) self._handle.recalculate(self._next_iteration)
def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True, loop=None): def loop(
*,
seconds: float = MISSING,
minutes: float = MISSING,
hours: float = MISSING,
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with """A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`. optional reconnect logic. The decorator returns a :class:`Loop`.
@ -663,7 +724,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True
Whether to handle errors and restart the task Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`. one used in :meth:`discord.Client.connect`.
loop: :class:`asyncio.AbstractEventLoop` loop: Optional[:class:`asyncio.AbstractEventLoop`]
The loop to use to register the task, if not given The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`. defaults to :func:`asyncio.get_event_loop`.
@ -675,7 +736,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True
The function was not a coroutine, an invalid value for the ``time`` parameter was passed, The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters. or ``time`` parameter was passed in conjunction with relative time parameters.
""" """
def decorator(func): def decorator(func: LF) -> Loop[LF]:
kwargs = { kwargs = {
'seconds': seconds, 'seconds': seconds,
'minutes': minutes, 'minutes': minutes,
@ -683,7 +744,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True
'count': count, 'count': count,
'time': time, 'time': time,
'reconnect': reconnect, 'reconnect': reconnect,
'loop': loop 'loop': loop,
} }
return Loop(func, **kwargs) return Loop(func, **kwargs)
return decorator return decorator

Loading…
Cancel
Save