diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 89077664b..5c57a330e 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple +from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple T = TypeVar('T') @@ -37,18 +37,16 @@ if TYPE_CHECKING: from .errors import CommandError P = ParamSpec('P') - MaybeCoroFunc = Union[ - Callable[P, 'Coro[T]'], - Callable[P, T], - ] + MaybeAwaitableFunc = Callable[P, 'MaybeAwaitable[T]'] else: P = TypeVar('P') - MaybeCoroFunc = Tuple[P, T] + MaybeAwaitableFunc = Tuple[P, T] _Bot = Union['Bot', 'AutoShardedBot'] Coro = Coroutine[Any, Any, T] -MaybeCoro = Union[T, Coro[T]] CoroFunc = Callable[..., Coro[Any]] +MaybeCoro = Union[T, Coro[T]] +MaybeAwaitable = Union[T, Awaitable[T]] Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]] Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index dde7545bd..d52194645 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -74,11 +74,11 @@ if TYPE_CHECKING: Check, CoroFunc, ContextT, - MaybeCoroFunc, + MaybeAwaitableFunc, ) _Prefix = Union[Iterable[str], str] - _PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix] + _PrefixCallable = MaybeAwaitableFunc[[BotT, Message], _Prefix] PrefixType = Union[_Prefix, _PrefixCallable[BotT]] __all__ = ( diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index b0c3f7dc6..bf0e12104 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -31,6 +31,7 @@ import re from typing import ( TYPE_CHECKING, + Awaitable, Optional, Generator, Generic, @@ -54,7 +55,6 @@ if TYPE_CHECKING: import discord.abc - from ._types import Coro from .bot import BotBase from .cog import Cog from .context import Context @@ -295,7 +295,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): bot.remove_command(self.name) self._eject_cog() - async def _call_without_cog(self, callback: Callable[[ContextT], Coro[T]], ctx: ContextT) -> T: + async def _call_without_cog(self, callback: Callable[[ContextT], Awaitable[T]], ctx: ContextT) -> T: cog = self._cog self.cog = None try: diff --git a/discord/member.py b/discord/member.py index a8d1ca486..0b5ad4cbc 100644 --- a/discord/member.py +++ b/discord/member.py @@ -28,7 +28,7 @@ import datetime import inspect import itertools from operator import attrgetter -from typing import Any, Callable, Collection, Coroutine, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, Type +from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, Type import discord.abc @@ -331,7 +331,7 @@ class Member(discord.abc.Messageable, _UserTag): default_avatar: Asset avatar: Optional[Asset] dm_channel: Optional[DMChannel] - create_dm: Callable[[], Coroutine[Any, Any, DMChannel]] + create_dm: Callable[[], Awaitable[DMChannel]] mutual_guilds: List[Guild] public_flags: PublicUserFlags banner: Optional[Asset] diff --git a/discord/utils.py b/discord/utils.py index aaa18b780..a3997930c 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -140,10 +140,7 @@ if TYPE_CHECKING: P = ParamSpec('P') - MaybeCoroFunc = Union[ - Callable[P, Coroutine[Any, Any, 'T']], - Callable[P, 'T'], - ] + MaybeAwaitableFunc = Callable[P, 'MaybeAwaitable[T]'] _SnowflakeListBase = array.array[int] @@ -156,6 +153,7 @@ T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) _Iter = Union[Iterable[T], AsyncIterable[T]] Coro = Coroutine[Any, Any, T] +MaybeAwaitable = Union[T, Awaitable[T]] class CachedSlotProperty(Generic[T, T_co]): @@ -615,7 +613,7 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) -async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T: +async def maybe_coroutine(f: MaybeAwaitableFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T: value = f(*args, **kwargs) if _isawaitable(value): return await value