Browse Source

[commands] Fix typing of check/check_any

This changes the type information of check decorators to return a
protocol representing that the decorator leaves the underlying object
unchanged while having a .predicate attribute.

resolves #7949
pull/7987/head
Michael H 3 years ago
committed by GitHub
parent
commit
d0667d08e3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      discord/ext/commands/_types.py
  2. 10
      discord/ext/commands/bot.py
  3. 32
      discord/ext/commands/core.py
  4. 6
      discord/ext/commands/help.py

14
discord/ext/commands/_types.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
""" """
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple, Optional from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional
T = TypeVar('T') T = TypeVar('T')
@ -49,13 +49,23 @@ MaybeCoro = Union[T, Coro[T]]
MaybeAwaitable = Union[T, Awaitable[T]] MaybeAwaitable = Union[T, Awaitable[T]]
CogT = TypeVar('CogT', bound='Optional[Cog]') CogT = TypeVar('CogT', bound='Optional[Cog]')
Check = Callable[["ContextT"], MaybeCoro[bool]] UserCheck = Callable[["ContextT"], MaybeCoro[bool]]
Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]] Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
ContextT = TypeVar('ContextT', bound='Context[Any]') ContextT = TypeVar('ContextT', bound='Context[Any]')
BotT = TypeVar('BotT', bound=_Bot, covariant=True) BotT = TypeVar('BotT', bound=_Bot, covariant=True)
ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True)
class Check(Protocol[ContextT_co]):
predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]]
def __call__(self, coro_or_commands: T) -> T:
...
# This is merely a tag type to avoid circular import issues. # This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution. # Yes, this is a terrible solution but ultimately it is the only solution.

10
discord/ext/commands/bot.py

@ -73,7 +73,7 @@ if TYPE_CHECKING:
from ._types import ( from ._types import (
_Bot, _Bot,
BotT, BotT,
Check, UserCheck,
CoroFunc, CoroFunc,
ContextT, ContextT,
MaybeAwaitableFunc, MaybeAwaitableFunc,
@ -173,8 +173,8 @@ class BotBase(GroupMixin[None]):
self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore
self.__cogs: Dict[str, Cog] = {} self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {} self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = [] self._checks: List[UserCheck] = []
self._check_once: List[Check] = [] self._check_once: List[UserCheck] = []
self._before_invoke: Optional[CoroFunc] = None self._before_invoke: Optional[CoroFunc] = None
self._after_invoke: Optional[CoroFunc] = None self._after_invoke: Optional[CoroFunc] = None
self._help_command: Optional[HelpCommand] = None self._help_command: Optional[HelpCommand] = None
@ -359,7 +359,7 @@ class BotBase(GroupMixin[None]):
self.add_check(func) # type: ignore self.add_check(func) # type: ignore
return func return func
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot. """Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check` This is the non-decorator interface to :meth:`.check`
@ -383,7 +383,7 @@ class BotBase(GroupMixin[None]):
else: else:
self._checks.append(func) self._checks.append(func)
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None: def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot. """Removes a global check from the bot.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception

32
discord/ext/commands/core.py

@ -60,7 +60,7 @@ if TYPE_CHECKING:
from discord.message import Message from discord.message import Message
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
__all__ = ( __all__ = (
@ -378,7 +378,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError: except AttributeError:
checks = kwargs.get('checks', []) checks = kwargs.get('checks', [])
self.checks: List[Check[ContextT]] = checks self.checks: List[UserCheck[ContextT]] = checks
try: try:
cooldown = func.__commands_cooldown__ cooldown = func.__commands_cooldown__
@ -458,7 +458,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns) self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check[ContextT], /) -> None: def add_check(self, func: UserCheck[ContextT], /) -> None:
"""Adds a check to the command. """Adds a check to the command.
This is the non-decorator interface to :func:`.check`. This is the non-decorator interface to :func:`.check`.
@ -477,7 +477,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func) self.checks.append(func)
def remove_check(self, func: Check[ContextT], /) -> None: def remove_check(self, func: UserCheck[ContextT], /) -> None:
"""Removes a check from the command. """Removes a check from the command.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@ -1745,7 +1745,7 @@ def group(
return command(name=name, cls=cls, **attrs) return command(name=name, cls=cls, **attrs)
def check(predicate: Check[ContextT], /) -> Callable[[T], T]: def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
r"""A decorator that adds a check to the :class:`.Command` or its r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`. subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1844,7 +1844,7 @@ def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
return decorator # type: ignore return decorator # type: ignore
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]: def check_any(*checks: Check[ContextT]) -> Check[ContextT]:
r"""A :func:`check` that is added that checks if any of the checks passed r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR. will pass, i.e. using logical OR.
@ -1910,10 +1910,10 @@ def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
# if we're here, all checks failed # if we're here, all checks failed
raise CheckAnyFailure(unwrapped, errors) raise CheckAnyFailure(unwrapped, errors)
return check(predicate) return check(predicate) # type: ignore
def has_role(item: Union[int, str], /) -> Callable[[T], T]: def has_role(item: Union[int, str], /) -> Check[Any]:
"""A :func:`.check` that is added that checks if the member invoking the """A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified. command has the role specified via the name or ID specified.
@ -2066,7 +2066,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def has_permissions(**perms: bool) -> Callable[[T], T]: def has_permissions(**perms: bool) -> Check[Any]:
"""A :func:`.check` that is added that checks if the member has all of """A :func:`.check` that is added that checks if the member has all of
the permissions necessary. the permissions necessary.
@ -2114,7 +2114,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]: def bot_has_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has """Similar to :func:`.has_permissions` except checks if the bot itself has
the permissions listed. the permissions listed.
@ -2141,7 +2141,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]: def has_guild_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_permissions`, but operates on guild wide """Similar to :func:`.has_permissions`, but operates on guild wide
permissions instead of the current channel permissions. permissions instead of the current channel permissions.
@ -2170,7 +2170,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: def bot_has_guild_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot """Similar to :func:`.has_guild_permissions`, but checks the bot
members guild permissions. members guild permissions.
@ -2196,7 +2196,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def dm_only() -> Callable[[T], T]: def dm_only() -> Check[Any]:
"""A :func:`.check` that indicates this command must only be used in a """A :func:`.check` that indicates this command must only be used in a
DM context. Only private messages are allowed when DM context. Only private messages are allowed when
using the command. using the command.
@ -2215,7 +2215,7 @@ def dm_only() -> Callable[[T], T]:
return check(predicate) return check(predicate)
def guild_only() -> Callable[[T], T]: def guild_only() -> Check[Any]:
"""A :func:`.check` that indicates this command must only be used in a """A :func:`.check` that indicates this command must only be used in a
guild context only. Basically, no private messages are allowed when guild context only. Basically, no private messages are allowed when
using the command. using the command.
@ -2232,7 +2232,7 @@ def guild_only() -> Callable[[T], T]:
return check(predicate) return check(predicate)
def is_owner() -> Callable[[T], T]: def is_owner() -> Check[Any]:
"""A :func:`.check` that checks if the person invoking this command is the """A :func:`.check` that checks if the person invoking this command is the
owner of the bot. owner of the bot.
@ -2250,7 +2250,7 @@ def is_owner() -> Callable[[T], T]:
return check(predicate) return check(predicate)
def is_nsfw() -> Callable[[T], T]: def is_nsfw() -> Check[Any]:
"""A :func:`.check` that checks if the channel is a NSFW channel. """A :func:`.check` that checks if the channel is a NSFW channel.
This check raises a special exception, :exc:`.NSFWChannelRequired` This check raises a special exception, :exc:`.NSFWChannelRequired`

6
discord/ext/commands/help.py

@ -60,7 +60,7 @@ if TYPE_CHECKING:
from .parameters import Parameter from .parameters import Parameter
from ._types import ( from ._types import (
Check, UserCheck,
ContextT, ContextT,
BotT, BotT,
_Bot, _Bot,
@ -378,7 +378,7 @@ class HelpCommand:
bot.remove_command(self._command_impl.name) bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog() self._command_impl._eject_cog()
def add_check(self, func: Check[ContextT], /) -> None: def add_check(self, func: UserCheck[ContextT], /) -> None:
""" """
Adds a check to the help command. Adds a check to the help command.
@ -396,7 +396,7 @@ class HelpCommand:
self._command_impl.add_check(func) self._command_impl.add_check(func)
def remove_check(self, func: Check[ContextT], /) -> None: def remove_check(self, func: UserCheck[ContextT], /) -> None:
""" """
Removes a check from the help command. Removes a check from the help command.

Loading…
Cancel
Save