From d0667d08e356256bf87cf0151f0f4d2f94e570f7 Mon Sep 17 00:00:00 2001 From: Michael H Date: Mon, 2 May 2022 18:54:49 -0400 Subject: [PATCH] [commands] Fix typing of check/check_any This changes the type information of check decorators to return a protocol representing that the decorator leaves the underlying object unchanged while having a .predicate attribute. resolves #7949 --- discord/ext/commands/_types.py | 14 ++++++++++++-- discord/ext/commands/bot.py | 10 +++++----- discord/ext/commands/core.py | 32 ++++++++++++++++---------------- discord/ext/commands/help.py | 6 +++--- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index b70a4bfdb..a048eb78f 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple, Optional +from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional T = TypeVar('T') @@ -49,13 +49,23 @@ MaybeCoro = Union[T, Coro[T]] MaybeAwaitable = Union[T, Awaitable[T]] CogT = TypeVar('CogT', bound='Optional[Cog]') -Check = Callable[["ContextT"], MaybeCoro[bool]] +UserCheck = Callable[["ContextT"], MaybeCoro[bool]] Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]] ContextT = TypeVar('ContextT', bound='Context[Any]') BotT = TypeVar('BotT', bound=_Bot, covariant=True) +ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True) + + +class Check(Protocol[ContextT_co]): + + predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]] + + def __call__(self, coro_or_commands: T) -> T: + ... + # This is merely a tag type to avoid circular import issues. # Yes, this is a terrible solution but ultimately it is the only solution. diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index ed65ebda5..983596ef9 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -73,7 +73,7 @@ if TYPE_CHECKING: from ._types import ( _Bot, BotT, - Check, + UserCheck, CoroFunc, ContextT, MaybeAwaitableFunc, @@ -173,8 +173,8 @@ class BotBase(GroupMixin[None]): self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} - self._checks: List[Check] = [] - self._check_once: List[Check] = [] + self._checks: List[UserCheck] = [] + self._check_once: List[UserCheck] = [] self._before_invoke: Optional[CoroFunc] = None self._after_invoke: Optional[CoroFunc] = None self._help_command: Optional[HelpCommand] = None @@ -359,7 +359,7 @@ class BotBase(GroupMixin[None]): self.add_check(func) # type: ignore return func - def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: + def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None: """Adds a global check to the bot. This is the non-decorator interface to :meth:`.check` @@ -383,7 +383,7 @@ class BotBase(GroupMixin[None]): else: self._checks.append(func) - def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: + def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None: """Removes a global check from the bot. This function is idempotent and will not raise an exception diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 6f6296fba..1f467ab91 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -60,7 +60,7 @@ if TYPE_CHECKING: from discord.message import Message - from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook + from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck __all__ = ( @@ -378,7 +378,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: checks = kwargs.get('checks', []) - self.checks: List[Check[ContextT]] = checks + self.checks: List[UserCheck[ContextT]] = checks try: cooldown = func.__commands_cooldown__ @@ -458,7 +458,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns) - def add_check(self, func: Check[ContextT], /) -> None: + def add_check(self, func: UserCheck[ContextT], /) -> None: """Adds a check to the command. This is the non-decorator interface to :func:`.check`. @@ -477,7 +477,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.checks.append(func) - def remove_check(self, func: Check[ContextT], /) -> None: + def remove_check(self, func: UserCheck[ContextT], /) -> None: """Removes a check from the command. This function is idempotent and will not raise an exception @@ -1745,7 +1745,7 @@ def group( return command(name=name, cls=cls, **attrs) -def check(predicate: Check[ContextT], /) -> Callable[[T], T]: +def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1844,7 +1844,7 @@ def check(predicate: Check[ContextT], /) -> Callable[[T], T]: return decorator # type: ignore -def check_any(*checks: Check[ContextT]) -> Callable[[T], T]: +def check_any(*checks: Check[ContextT]) -> Check[ContextT]: r"""A :func:`check` that is added that checks if any of the checks passed will pass, i.e. using logical OR. @@ -1910,10 +1910,10 @@ def check_any(*checks: Check[ContextT]) -> Callable[[T], T]: # if we're here, all checks failed raise CheckAnyFailure(unwrapped, errors) - return check(predicate) + return check(predicate) # type: ignore -def has_role(item: Union[int, str], /) -> Callable[[T], T]: +def has_role(item: Union[int, str], /) -> Check[Any]: """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. @@ -2066,7 +2066,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: return check(predicate) -def has_permissions(**perms: bool) -> Callable[[T], T]: +def has_permissions(**perms: bool) -> Check[Any]: """A :func:`.check` that is added that checks if the member has all of the permissions necessary. @@ -2114,7 +2114,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def bot_has_permissions(**perms: bool) -> Callable[[T], T]: +def bot_has_permissions(**perms: bool) -> Check[Any]: """Similar to :func:`.has_permissions` except checks if the bot itself has the permissions listed. @@ -2141,7 +2141,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def has_guild_permissions(**perms: bool) -> Callable[[T], T]: +def has_guild_permissions(**perms: bool) -> Check[Any]: """Similar to :func:`.has_permissions`, but operates on guild wide permissions instead of the current channel permissions. @@ -2170,7 +2170,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: +def bot_has_guild_permissions(**perms: bool) -> Check[Any]: """Similar to :func:`.has_guild_permissions`, but checks the bot members guild permissions. @@ -2196,7 +2196,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def dm_only() -> Callable[[T], T]: +def dm_only() -> Check[Any]: """A :func:`.check` that indicates this command must only be used in a DM context. Only private messages are allowed when using the command. @@ -2215,7 +2215,7 @@ def dm_only() -> Callable[[T], T]: return check(predicate) -def guild_only() -> Callable[[T], T]: +def guild_only() -> Check[Any]: """A :func:`.check` that indicates this command must only be used in a guild context only. Basically, no private messages are allowed when using the command. @@ -2232,7 +2232,7 @@ def guild_only() -> Callable[[T], T]: return check(predicate) -def is_owner() -> Callable[[T], T]: +def is_owner() -> Check[Any]: """A :func:`.check` that checks if the person invoking this command is the owner of the bot. @@ -2250,7 +2250,7 @@ def is_owner() -> Callable[[T], T]: return check(predicate) -def is_nsfw() -> Callable[[T], T]: +def is_nsfw() -> Check[Any]: """A :func:`.check` that checks if the channel is a NSFW channel. This check raises a special exception, :exc:`.NSFWChannelRequired` diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 6dbd15f9d..c43ad4c9c 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -60,7 +60,7 @@ if TYPE_CHECKING: from .parameters import Parameter from ._types import ( - Check, + UserCheck, ContextT, BotT, _Bot, @@ -378,7 +378,7 @@ class HelpCommand: bot.remove_command(self._command_impl.name) self._command_impl._eject_cog() - def add_check(self, func: Check[ContextT], /) -> None: + def add_check(self, func: UserCheck[ContextT], /) -> None: """ Adds a check to the help command. @@ -396,7 +396,7 @@ class HelpCommand: self._command_impl.add_check(func) - def remove_check(self, func: Check[ContextT], /) -> None: + def remove_check(self, func: UserCheck[ContextT], /) -> None: """ Removes a check from the help command.