Browse Source

[tasks] Improve typing parity

pull/7488/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
a2a7b0f076
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 96
      discord/ext/tasks/__init__.py

96
discord/ext/tasks/__init__.py

@ -36,13 +36,11 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
cast,
) )
import aiohttp import aiohttp
import discord import discord
import inspect import inspect
import logging
import sys import sys
import traceback import traceback
@ -50,8 +48,6 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING
_log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'loop', 'loop',
) )
@ -61,7 +57,6 @@ _func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func) LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func) FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop')
class SleepHandle: class SleepHandle:
@ -78,7 +73,7 @@ class SleepHandle:
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future: def wait(self) -> asyncio.Future[Any]:
return self.future return self.future
def done(self) -> bool: def done(self) -> bool:
@ -94,7 +89,9 @@ class Loop(Generic[LF]):
The main interface to create this is through :func:`loop`. The main interface to create this is through :func:`loop`.
""" """
def __init__(self,
def __init__(
self,
coro: LF, coro: LF,
seconds: float, seconds: float,
hours: float, hours: float,
@ -102,15 +99,15 @@ class Loop(Generic[LF]):
time: Union[datetime.time, Sequence[datetime.time]], time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int], count: Optional[int],
reconnect: bool, reconnect: bool,
loop: Optional[asyncio.AbstractEventLoop], loop: asyncio.AbstractEventLoop,
) -> None: ) -> None:
self.coro: LF = coro self.coro: LF = coro
self.reconnect: bool = reconnect self.reconnect: bool = reconnect
self.loop: Optional[asyncio.AbstractEventLoop] = loop self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count self.count: Optional[int] = count
self._current_loop = 0 self._current_loop = 0
self._handle = None self._handle: SleepHandle = MISSING
self._task = None self._task: asyncio.Task[None] = MISSING
self._injected = None self._injected = None
self._valid_exception = ( self._valid_exception = (
OSError, OSError,
@ -131,7 +128,7 @@ class Loop(Generic[LF]):
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False self._last_iteration_failed = False
self._last_iteration = None self._last_iteration: datetime.datetime = MISSING
self._next_iteration = None self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro): if not inspect.iscoroutinefunction(self.coro):
@ -147,9 +144,8 @@ class Loop(Generic[LF]):
else: else:
await coro(*args, **kwargs) await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime): def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore self._handle = SleepHandle(dt=dt, loop=self.loop)
return self._handle.wait() return self._handle.wait()
async def _loop(self, *args: Any, **kwargs: Any) -> None: async def _loop(self, *args: Any, **kwargs: Any) -> None:
@ -211,7 +207,7 @@ class Loop(Generic[LF]):
if obj is None: if obj is None:
return self return self
copy = Loop( copy: Loop[LF] = Loop(
self.coro, self.coro,
seconds=self._seconds, seconds=self._seconds,
hours=self._hours, hours=self._hours,
@ -279,7 +275,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
if self._task is None: if self._task is MISSING:
return None return None
elif self._task and self._task.done() or self._stop_next_iteration: elif self._task and self._task.done() or self._stop_next_iteration:
return None return None
@ -305,7 +301,7 @@ class Loop(Generic[LF]):
return await self.coro(*args, **kwargs) return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
r"""Starts the internal task in the event loop. r"""Starts the internal task in the event loop.
Parameters Parameters
@ -326,13 +322,13 @@ class Loop(Generic[LF]):
The task that has been created. The task that has been created.
""" """
if self._task is not None and not self._task.done(): if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.') raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None: if self._injected is not None:
args = (self._injected, *args) args = (self._injected, *args)
if self.loop is None: if self.loop is MISSING:
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs)) self._task = self.loop.create_task(self._loop(*args, **kwargs))
@ -356,7 +352,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.2 .. versionadded:: 1.2
""" """
if self._task and not self._task.done(): if self._task is not MISSING and not self._task.done():
self._stop_next_iteration = True self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool: def _can_be_cancelled(self) -> bool:
@ -383,7 +379,7 @@ class Loop(Generic[LF]):
The keyword arguments to use. The keyword arguments to use.
""" """
def restart_when_over(fut, *, args=args, kwargs=kwargs): def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
self._task.remove_done_callback(restart_when_over) self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs) self.start(*args, **kwargs)
@ -446,9 +442,9 @@ class Loop(Generic[LF]):
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions) return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task]: def get_task(self) -> Optional[asyncio.Task[None]]:
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task return self._task if self._task is not MISSING else None
def is_being_cancelled(self) -> bool: def is_being_cancelled(self) -> bool:
"""Whether the task is being cancelled.""" """Whether the task is being cancelled."""
@ -466,7 +462,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.4 .. versionadded:: 1.4
""" """
return not bool(self._task.done()) if self._task else False return not bool(self._task.done()) if self._task is not MISSING else False
async def _error(self, *args: Any) -> None: async def _error(self, *args: Any) -> None:
exception: Exception = args[-1] exception: Exception = args[-1]
@ -560,7 +556,9 @@ class Loop(Generic[LF]):
self._time_index = 0 self._time_index = 0
if self._current_loop == 0: if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow # if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]) return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
next_time = self._time[self._time_index] next_time = self._time[self._time_index]
@ -568,7 +566,7 @@ class Loop(Generic[LF]):
self._time_index += 1 self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
next_date = cast(datetime.datetime, self._last_iteration) next_date = self._last_iteration
if self._time_index == 0: if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow" # we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1) next_date += datetime.timedelta(days=1)
@ -576,12 +574,14 @@ class Loop(Generic[LF]):
self._time_index += 1 self._time_index += 1
return datetime.datetime.combine(next_date, next_time) return datetime.datetime.combine(next_date, next_time)
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None: def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
# now kwarg should be a datetime.datetime representing the time "now" # now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from # to calculate the next time index from
# pre-condition: self._time is set # pre-condition: self._time is set
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz() time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
for idx, time in enumerate(self._time): for idx, time in enumerate(self._time):
if time >= time_now: if time >= time_now:
self._time_index = idx self._time_index = idx
@ -597,17 +597,21 @@ class Loop(Generic[LF]):
utc: datetime.timezone = datetime.timezone.utc, utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]: ) -> List[datetime.time]:
if isinstance(time, dt): if isinstance(time, dt):
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [ret] return [inner]
if not isinstance(time, Sequence): if not isinstance(time, Sequence):
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
)
if not time: if not time:
raise ValueError('time parameter must not be an empty sequence.') raise ValueError('time parameter must not be an empty sequence.')
ret = [] ret: List[datetime.time] = []
for index, t in enumerate(time): for index, t in enumerate(time):
if not isinstance(t, dt): if not isinstance(t, dt):
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times ret = sorted(set(ret)) # de-dupe and sort times
@ -691,7 +695,7 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING, time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None, count: Optional[int] = None,
reconnect: bool = True, reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]: ) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with """A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`. optional reconnect logic. The decorator returns a :class:`Loop`.
@ -724,7 +728,7 @@ def loop(
Whether to handle errors and restart the task Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`. one used in :meth:`discord.Client.connect`.
loop: Optional[:class:`asyncio.AbstractEventLoop`] loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`. defaults to :func:`asyncio.get_event_loop`.
@ -736,15 +740,17 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed, The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters. or ``time`` parameter was passed in conjunction with relative time parameters.
""" """
def decorator(func: LF) -> Loop[LF]: def decorator(func: LF) -> Loop[LF]:
kwargs = { return Loop[LF](
'seconds': seconds, func,
'minutes': minutes, seconds=seconds,
'hours': hours, minutes=minutes,
'count': count, hours=hours,
'time': time, count=count,
'reconnect': reconnect, time=time,
'loop': loop, reconnect=reconnect,
} loop=loop,
return Loop(func, **kwargs) )
return decorator return decorator

Loading…
Cancel
Save