Browse Source

[tasks] Improve typing parity

pull/7488/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
a2a7b0f076
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 128
      discord/ext/tasks/__init__.py

128
discord/ext/tasks/__init__.py

@ -27,22 +27,20 @@ from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
from typing import ( from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Generic, Generic,
List, List,
Optional, Optional,
Type, Type,
TypeVar, TypeVar,
Union, Union,
cast,
) )
import aiohttp import aiohttp
import discord import discord
import inspect import inspect
import logging
import sys import sys
import traceback import traceback
@ -50,8 +48,6 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING
_log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'loop', 'loop',
) )
@ -61,7 +57,6 @@ _func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func) LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func) FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop')
class SleepHandle: class SleepHandle:
@ -78,7 +73,7 @@ class SleepHandle:
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future: def wait(self) -> asyncio.Future[Any]:
return self.future return self.future
def done(self) -> bool: def done(self) -> bool:
@ -94,7 +89,9 @@ class Loop(Generic[LF]):
The main interface to create this is through :func:`loop`. The main interface to create this is through :func:`loop`.
""" """
def __init__(self,
def __init__(
self,
coro: LF, coro: LF,
seconds: float, seconds: float,
hours: float, hours: float,
@ -102,15 +99,15 @@ class Loop(Generic[LF]):
time: Union[datetime.time, Sequence[datetime.time]], time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int], count: Optional[int],
reconnect: bool, reconnect: bool,
loop: Optional[asyncio.AbstractEventLoop], loop: asyncio.AbstractEventLoop,
) -> None: ) -> None:
self.coro: LF = coro self.coro: LF = coro
self.reconnect: bool = reconnect self.reconnect: bool = reconnect
self.loop: Optional[asyncio.AbstractEventLoop] = loop self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count self.count: Optional[int] = count
self._current_loop = 0 self._current_loop = 0
self._handle = None self._handle: SleepHandle = MISSING
self._task = None self._task: asyncio.Task[None] = MISSING
self._injected = None self._injected = None
self._valid_exception = ( self._valid_exception = (
OSError, OSError,
@ -131,7 +128,7 @@ class Loop(Generic[LF]):
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False self._last_iteration_failed = False
self._last_iteration = None self._last_iteration: datetime.datetime = MISSING
self._next_iteration = None self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro): if not inspect.iscoroutinefunction(self.coro):
@ -147,9 +144,8 @@ class Loop(Generic[LF]):
else: else:
await coro(*args, **kwargs) await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime): def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore self._handle = SleepHandle(dt=dt, loop=self.loop)
return self._handle.wait() return self._handle.wait()
async def _loop(self, *args: Any, **kwargs: Any) -> None: async def _loop(self, *args: Any, **kwargs: Any) -> None:
@ -178,7 +174,7 @@ class Loop(Generic[LF]):
await asyncio.sleep(backoff.delay()) await asyncio.sleep(backoff.delay())
else: else:
await self._try_sleep_until(self._next_iteration) await self._try_sleep_until(self._next_iteration)
if self._stop_next_iteration: if self._stop_next_iteration:
return return
@ -211,14 +207,14 @@ class Loop(Generic[LF]):
if obj is None: if obj is None:
return self return self
copy = Loop( copy: Loop[LF] = Loop(
self.coro, self.coro,
seconds=self._seconds, seconds=self._seconds,
hours=self._hours, hours=self._hours,
minutes=self._minutes, minutes=self._minutes,
time=self._time, time=self._time,
count=self.count, count=self.count,
reconnect=self.reconnect, reconnect=self.reconnect,
loop=self.loop, loop=self.loop,
) )
copy._injected = obj copy._injected = obj
@ -237,7 +233,7 @@ class Loop(Generic[LF]):
""" """
if self._seconds is not MISSING: if self._seconds is not MISSING:
return self._seconds return self._seconds
@property @property
def minutes(self) -> Optional[float]: def minutes(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of minutes """Optional[:class:`float`]: Read-only value for the number of minutes
@ -247,7 +243,7 @@ class Loop(Generic[LF]):
""" """
if self._minutes is not MISSING: if self._minutes is not MISSING:
return self._minutes return self._minutes
@property @property
def hours(self) -> Optional[float]: def hours(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of hours """Optional[:class:`float`]: Read-only value for the number of hours
@ -279,7 +275,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
if self._task is None: if self._task is MISSING:
return None return None
elif self._task and self._task.done() or self._stop_next_iteration: elif self._task and self._task.done() or self._stop_next_iteration:
return None return None
@ -305,7 +301,7 @@ class Loop(Generic[LF]):
return await self.coro(*args, **kwargs) return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
r"""Starts the internal task in the event loop. r"""Starts the internal task in the event loop.
Parameters Parameters
@ -326,13 +322,13 @@ class Loop(Generic[LF]):
The task that has been created. The task that has been created.
""" """
if self._task is not None and not self._task.done(): if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.') raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None: if self._injected is not None:
args = (self._injected, *args) args = (self._injected, *args)
if self.loop is None: if self.loop is MISSING:
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs)) self._task = self.loop.create_task(self._loop(*args, **kwargs))
@ -356,7 +352,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.2 .. versionadded:: 1.2
""" """
if self._task and not self._task.done(): if self._task is not MISSING and not self._task.done():
self._stop_next_iteration = True self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool: def _can_be_cancelled(self) -> bool:
@ -383,7 +379,7 @@ class Loop(Generic[LF]):
The keyword arguments to use. The keyword arguments to use.
""" """
def restart_when_over(fut, *, args=args, kwargs=kwargs): def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
self._task.remove_done_callback(restart_when_over) self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs) self.start(*args, **kwargs)
@ -446,9 +442,9 @@ class Loop(Generic[LF]):
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions) return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task]: def get_task(self) -> Optional[asyncio.Task[None]]:
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task return self._task if self._task is not MISSING else None
def is_being_cancelled(self) -> bool: def is_being_cancelled(self) -> bool:
"""Whether the task is being cancelled.""" """Whether the task is being cancelled."""
@ -466,7 +462,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.4 .. versionadded:: 1.4
""" """
return not bool(self._task.done()) if self._task else False return not bool(self._task.done()) if self._task is not MISSING else False
async def _error(self, *args: Any) -> None: async def _error(self, *args: Any) -> None:
exception: Exception = args[-1] exception: Exception = args[-1]
@ -560,7 +556,9 @@ class Loop(Generic[LF]):
self._time_index = 0 self._time_index = 0
if self._current_loop == 0: if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow # if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]) return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
next_time = self._time[self._time_index] next_time = self._time[self._time_index]
@ -568,7 +566,7 @@ class Loop(Generic[LF]):
self._time_index += 1 self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
next_date = cast(datetime.datetime, self._last_iteration) next_date = self._last_iteration
if self._time_index == 0: if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow" # we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1) next_date += datetime.timedelta(days=1)
@ -576,12 +574,14 @@ class Loop(Generic[LF]):
self._time_index += 1 self._time_index += 1
return datetime.datetime.combine(next_date, next_time) return datetime.datetime.combine(next_date, next_time)
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None: def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
# now kwarg should be a datetime.datetime representing the time "now" # now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from # to calculate the next time index from
# pre-condition: self._time is set # pre-condition: self._time is set
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz() time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
for idx, time in enumerate(self._time): for idx, time in enumerate(self._time):
if time >= time_now: if time >= time_now:
self._time_index = idx self._time_index = idx
@ -597,20 +597,24 @@ class Loop(Generic[LF]):
utc: datetime.timezone = datetime.timezone.utc, utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]: ) -> List[datetime.time]:
if isinstance(time, dt): if isinstance(time, dt):
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [ret] return [inner]
if not isinstance(time, Sequence): if not isinstance(time, Sequence):
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
)
if not time: if not time:
raise ValueError('time parameter must not be an empty sequence.') raise ValueError('time parameter must not be an empty sequence.')
ret = [] ret: List[datetime.time] = []
for index, t in enumerate(time): for index, t in enumerate(time):
if not isinstance(t, dt): if not isinstance(t, dt):
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times ret = sorted(set(ret)) # de-dupe and sort times
return ret return ret
def change_interval( def change_interval(
@ -691,7 +695,7 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING, time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None, count: Optional[int] = None,
reconnect: bool = True, reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]: ) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with """A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`. optional reconnect logic. The decorator returns a :class:`Loop`.
@ -707,7 +711,7 @@ def loop(
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]] time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed. Timezones are supported. value of :class:`datetime.time` should be passed. Timezones are supported.
If no timezone is given for the times, it is assumed to represent UTC time. If no timezone is given for the times, it is assumed to represent UTC time.
This cannot be used in conjunction with the relative time parameters. This cannot be used in conjunction with the relative time parameters.
@ -724,7 +728,7 @@ def loop(
Whether to handle errors and restart the task Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`. one used in :meth:`discord.Client.connect`.
loop: Optional[:class:`asyncio.AbstractEventLoop`] loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`. defaults to :func:`asyncio.get_event_loop`.
@ -736,15 +740,17 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed, The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters. or ``time`` parameter was passed in conjunction with relative time parameters.
""" """
def decorator(func: LF) -> Loop[LF]: def decorator(func: LF) -> Loop[LF]:
kwargs = { return Loop[LF](
'seconds': seconds, func,
'minutes': minutes, seconds=seconds,
'hours': hours, minutes=minutes,
'count': count, hours=hours,
'time': time, count=count,
'reconnect': reconnect, time=time,
'loop': loop, reconnect=reconnect,
} loop=loop,
return Loop(func, **kwargs) )
return decorator return decorator

Loading…
Cancel
Save