diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index be97cc9a3..7f2ff6ec4 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple, Optional +from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional T = TypeVar('T') @@ -51,13 +51,23 @@ MaybeCoro = Union[T, Coro[T]] MaybeAwaitable = Union[T, Awaitable[T]] CogT = TypeVar('CogT', bound='Optional[Cog]') -Check = Callable[["ContextT"], MaybeCoro[bool]] +UserCheck = Callable[["ContextT"], MaybeCoro[bool]] Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]] ContextT = TypeVar('ContextT', bound='Context[Any]') BotT = TypeVar('BotT', bound=_Bot, covariant=True) +ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True) + + +class Check(Protocol[ContextT_co]): + + predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]] + + def __call__(self, coro_or_commands: T) -> T: + ... + # This is merely a tag type to avoid circular import issues. # Yes, this is a terrible solution but ultimately it is the only solution. diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 2146a2213..a37b36d0c 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -69,7 +69,7 @@ if TYPE_CHECKING: from ._types import ( _Bot, BotT, - Check, + UserCheck, CoroFunc, ContextT, MaybeAwaitableFunc, @@ -161,8 +161,8 @@ class BotBase(GroupMixin[None]): self.extra_events: Dict[str, List[CoroFunc]] = {} self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} - self._checks: List[Check] = [] - self._check_once: List[Check] = [] + self._checks: List[UserCheck] = [] + self._check_once: List[UserCheck] = [] self._before_invoke: Optional[CoroFunc] = None self._after_invoke: Optional[CoroFunc] = None self._help_command: Optional[HelpCommand] = None @@ -282,7 +282,7 @@ class BotBase(GroupMixin[None]): self.add_check(func) # type: ignore return func - def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: + def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None: """Adds a global check to the bot. This is the non-decorator interface to :meth:`.check` @@ -306,7 +306,7 @@ class BotBase(GroupMixin[None]): else: self._checks.append(func) - def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: + def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None: """Removes a global check from the bot. This function is idempotent and will not raise an exception diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index e5b1f0256..6561f9020 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -60,7 +60,7 @@ if TYPE_CHECKING: from discord.message import Message - from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook + from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck __all__ = ( @@ -375,7 +375,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: checks = kwargs.get('checks', []) - self.checks: List[Check[ContextT]] = checks + self.checks: List[UserCheck[ContextT]] = checks try: cooldown = func.__commands_cooldown__ @@ -447,7 +447,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns) - def add_check(self, func: Check[ContextT], /) -> None: + def add_check(self, func: UserCheck[ContextT], /) -> None: """Adds a check to the command. This is the non-decorator interface to :func:`.check`. @@ -466,7 +466,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.checks.append(func) - def remove_check(self, func: Check[ContextT], /) -> None: + def remove_check(self, func: UserCheck[ContextT], /) -> None: """Removes a check from the command. This function is idempotent and will not raise an exception @@ -1734,7 +1734,7 @@ def group( return command(name=name, cls=cls, **attrs) -def check(predicate: Check[ContextT], /) -> Callable[[T], T]: +def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1833,7 +1833,7 @@ def check(predicate: Check[ContextT], /) -> Callable[[T], T]: return decorator # type: ignore -def check_any(*checks: Check[ContextT]) -> Callable[[T], T]: +def check_any(*checks: Check[ContextT]) -> Check[ContextT]: r"""A :func:`check` that is added that checks if any of the checks passed will pass, i.e. using logical OR. @@ -1899,10 +1899,10 @@ def check_any(*checks: Check[ContextT]) -> Callable[[T], T]: # if we're here, all checks failed raise CheckAnyFailure(unwrapped, errors) - return check(predicate) + return check(predicate) # type: ignore -def has_role(item: Union[int, str], /) -> Callable[[T], T]: +def has_role(item: Union[int, str], /) -> Check[Any]: """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. @@ -2055,7 +2055,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: return check(predicate) -def has_permissions(**perms: bool) -> Callable[[T], T]: +def has_permissions(**perms: bool) -> Check[Any]: """A :func:`.check` that is added that checks if the member has all of the permissions necessary. @@ -2103,7 +2103,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def bot_has_permissions(**perms: bool) -> Callable[[T], T]: +def bot_has_permissions(**perms: bool) -> Check[Any]: """Similar to :func:`.has_permissions` except checks if the bot itself has the permissions listed. @@ -2130,7 +2130,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def has_guild_permissions(**perms: bool) -> Callable[[T], T]: +def has_guild_permissions(**perms: bool) -> Check[Any]: """Similar to :func:`.has_permissions`, but operates on guild wide permissions instead of the current channel permissions. @@ -2159,7 +2159,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: +def bot_has_guild_permissions(**perms: bool) -> Check[Any]: """Similar to :func:`.has_guild_permissions`, but checks the bot members guild permissions. @@ -2185,7 +2185,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) -def dm_only() -> Callable[[T], T]: +def dm_only() -> Check[Any]: """A :func:`.check` that indicates this command must only be used in a DM context. Only private messages are allowed when using the command. @@ -2204,7 +2204,7 @@ def dm_only() -> Callable[[T], T]: return check(predicate) -def guild_only() -> Callable[[T], T]: +def guild_only() -> Check[Any]: """A :func:`.check` that indicates this command must only be used in a guild context only. Basically, no private messages are allowed when using the command. @@ -2221,7 +2221,7 @@ def guild_only() -> Callable[[T], T]: return check(predicate) -def is_owner() -> Callable[[T], T]: +def is_owner() -> Check[Any]: """A :func:`.check` that checks if the person invoking this command is the owner of the bot. @@ -2239,7 +2239,7 @@ def is_owner() -> Callable[[T], T]: return check(predicate) -def is_nsfw() -> Callable[[T], T]: +def is_nsfw() -> Check[Any]: """A :func:`.check` that checks if the channel is a NSFW channel. This check raises a special exception, :exc:`.NSFWChannelRequired` diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 5f1e9bcfa..d9c173991 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -61,7 +61,7 @@ if TYPE_CHECKING: from .parameters import Parameter from ._types import ( - Check, + UserCheck, ContextT, BotT, _Bot, @@ -377,7 +377,7 @@ class HelpCommand: bot.remove_command(self._command_impl.name) self._command_impl._eject_cog() - def add_check(self, func: Check[ContextT], /) -> None: + def add_check(self, func: UserCheck[ContextT], /) -> None: """ Adds a check to the help command. @@ -395,7 +395,7 @@ class HelpCommand: self._command_impl.add_check(func) - def remove_check(self, func: Check[ContextT], /) -> None: + def remove_check(self, func: UserCheck[ContextT], /) -> None: """ Removes a check from the help command.