Browse Source

[commands] Fix typing of check/check_any

This changes the type information of check decorators to return a
protocol representing that the decorator leaves the underlying object
unchanged while having a .predicate attribute.

resolves #7949
pull/10109/head
Michael H 3 years ago
committed by dolfies
parent
commit
3a1cda5ff2
  1. 14
      discord/ext/commands/_types.py
  2. 10
      discord/ext/commands/bot.py
  3. 32
      discord/ext/commands/core.py
  4. 6
      discord/ext/commands/help.py

14
discord/ext/commands/_types.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple, Optional
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional
T = TypeVar('T')
@ -51,13 +51,23 @@ MaybeCoro = Union[T, Coro[T]]
MaybeAwaitable = Union[T, Awaitable[T]]
CogT = TypeVar('CogT', bound='Optional[Cog]')
Check = Callable[["ContextT"], MaybeCoro[bool]]
UserCheck = Callable[["ContextT"], MaybeCoro[bool]]
Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
ContextT = TypeVar('ContextT', bound='Context[Any]')
BotT = TypeVar('BotT', bound=_Bot, covariant=True)
ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True)
class Check(Protocol[ContextT_co]):
predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]]
def __call__(self, coro_or_commands: T) -> T:
...
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.

10
discord/ext/commands/bot.py

@ -69,7 +69,7 @@ if TYPE_CHECKING:
from ._types import (
_Bot,
BotT,
Check,
UserCheck,
CoroFunc,
ContextT,
MaybeAwaitableFunc,
@ -161,8 +161,8 @@ class BotBase(GroupMixin[None]):
self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self._check_once: List[Check] = []
self._checks: List[UserCheck] = []
self._check_once: List[UserCheck] = []
self._before_invoke: Optional[CoroFunc] = None
self._after_invoke: Optional[CoroFunc] = None
self._help_command: Optional[HelpCommand] = None
@ -282,7 +282,7 @@ class BotBase(GroupMixin[None]):
self.add_check(func) # type: ignore
return func
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@ -306,7 +306,7 @@ class BotBase(GroupMixin[None]):
else:
self._checks.append(func)
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception

32
discord/ext/commands/core.py

@ -60,7 +60,7 @@ if TYPE_CHECKING:
from discord.message import Message
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
__all__ = (
@ -375,7 +375,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError:
checks = kwargs.get('checks', [])
self.checks: List[Check[ContextT]] = checks
self.checks: List[UserCheck[ContextT]] = checks
try:
cooldown = func.__commands_cooldown__
@ -447,7 +447,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check[ContextT], /) -> None:
def add_check(self, func: UserCheck[ContextT], /) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`.check`.
@ -466,7 +466,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func)
def remove_check(self, func: Check[ContextT], /) -> None:
def remove_check(self, func: UserCheck[ContextT], /) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
@ -1734,7 +1734,7 @@ def group(
return command(name=name, cls=cls, **attrs)
def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1833,7 +1833,7 @@ def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
return decorator # type: ignore
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
def check_any(*checks: Check[ContextT]) -> Check[ContextT]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@ -1899,10 +1899,10 @@ def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
# if we're here, all checks failed
raise CheckAnyFailure(unwrapped, errors)
return check(predicate)
return check(predicate) # type: ignore
def has_role(item: Union[int, str], /) -> Callable[[T], T]:
def has_role(item: Union[int, str], /) -> Check[Any]:
"""A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified.
@ -2055,7 +2055,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
return check(predicate)
def has_permissions(**perms: bool) -> Callable[[T], T]:
def has_permissions(**perms: bool) -> Check[Any]:
"""A :func:`.check` that is added that checks if the member has all of
the permissions necessary.
@ -2103,7 +2103,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
def bot_has_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has
the permissions listed.
@ -2130,7 +2130,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
def has_guild_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_permissions`, but operates on guild wide
permissions instead of the current channel permissions.
@ -2159,7 +2159,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
def bot_has_guild_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot
members guild permissions.
@ -2185,7 +2185,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def dm_only() -> Callable[[T], T]:
def dm_only() -> Check[Any]:
"""A :func:`.check` that indicates this command must only be used in a
DM context. Only private messages are allowed when
using the command.
@ -2204,7 +2204,7 @@ def dm_only() -> Callable[[T], T]:
return check(predicate)
def guild_only() -> Callable[[T], T]:
def guild_only() -> Check[Any]:
"""A :func:`.check` that indicates this command must only be used in a
guild context only. Basically, no private messages are allowed when
using the command.
@ -2221,7 +2221,7 @@ def guild_only() -> Callable[[T], T]:
return check(predicate)
def is_owner() -> Callable[[T], T]:
def is_owner() -> Check[Any]:
"""A :func:`.check` that checks if the person invoking this command is the
owner of the bot.
@ -2239,7 +2239,7 @@ def is_owner() -> Callable[[T], T]:
return check(predicate)
def is_nsfw() -> Callable[[T], T]:
def is_nsfw() -> Check[Any]:
"""A :func:`.check` that checks if the channel is a NSFW channel.
This check raises a special exception, :exc:`.NSFWChannelRequired`

6
discord/ext/commands/help.py

@ -61,7 +61,7 @@ if TYPE_CHECKING:
from .parameters import Parameter
from ._types import (
Check,
UserCheck,
ContextT,
BotT,
_Bot,
@ -377,7 +377,7 @@ class HelpCommand:
bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog()
def add_check(self, func: Check[ContextT], /) -> None:
def add_check(self, func: UserCheck[ContextT], /) -> None:
"""
Adds a check to the help command.
@ -395,7 +395,7 @@ class HelpCommand:
self._command_impl.add_check(func)
def remove_check(self, func: Check[ContextT], /) -> None:
def remove_check(self, func: UserCheck[ContextT], /) -> None:
"""
Removes a check from the help command.

Loading…
Cancel
Save