Browse Source

[commands] Fix typing of check/check_any

This changes the type information of check decorators to return a
protocol representing that the decorator leaves the underlying object
unchanged while having a .predicate attribute.

resolves #7949
pull/7987/head
Michael H 3 years ago
committed by GitHub
parent
commit
d0667d08e3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      discord/ext/commands/_types.py
  2. 10
      discord/ext/commands/bot.py
  3. 32
      discord/ext/commands/core.py
  4. 6
      discord/ext/commands/help.py

14
discord/ext/commands/_types.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple, Optional
from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional
T = TypeVar('T')
@ -49,13 +49,23 @@ MaybeCoro = Union[T, Coro[T]]
MaybeAwaitable = Union[T, Awaitable[T]]
CogT = TypeVar('CogT', bound='Optional[Cog]')
Check = Callable[["ContextT"], MaybeCoro[bool]]
UserCheck = Callable[["ContextT"], MaybeCoro[bool]]
Hook = Union[Callable[["CogT", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
Error = Union[Callable[["CogT", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
ContextT = TypeVar('ContextT', bound='Context[Any]')
BotT = TypeVar('BotT', bound=_Bot, covariant=True)
ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True)
class Check(Protocol[ContextT_co]):
predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]]
def __call__(self, coro_or_commands: T) -> T:
...
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.

10
discord/ext/commands/bot.py

@ -73,7 +73,7 @@ if TYPE_CHECKING:
from ._types import (
_Bot,
BotT,
Check,
UserCheck,
CoroFunc,
ContextT,
MaybeAwaitableFunc,
@ -173,8 +173,8 @@ class BotBase(GroupMixin[None]):
self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self._check_once: List[Check] = []
self._checks: List[UserCheck] = []
self._check_once: List[UserCheck] = []
self._before_invoke: Optional[CoroFunc] = None
self._after_invoke: Optional[CoroFunc] = None
self._help_command: Optional[HelpCommand] = None
@ -359,7 +359,7 @@ class BotBase(GroupMixin[None]):
self.add_check(func) # type: ignore
return func
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@ -383,7 +383,7 @@ class BotBase(GroupMixin[None]):
else:
self._checks.append(func)
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception

32
discord/ext/commands/core.py

@ -60,7 +60,7 @@ if TYPE_CHECKING:
from discord.message import Message
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook
from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck
__all__ = (
@ -378,7 +378,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError:
checks = kwargs.get('checks', [])
self.checks: List[Check[ContextT]] = checks
self.checks: List[UserCheck[ContextT]] = checks
try:
cooldown = func.__commands_cooldown__
@ -458,7 +458,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check[ContextT], /) -> None:
def add_check(self, func: UserCheck[ContextT], /) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`.check`.
@ -477,7 +477,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func)
def remove_check(self, func: Check[ContextT], /) -> None:
def remove_check(self, func: UserCheck[ContextT], /) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
@ -1745,7 +1745,7 @@ def group(
return command(name=name, cls=cls, **attrs)
def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1844,7 +1844,7 @@ def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
return decorator # type: ignore
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
def check_any(*checks: Check[ContextT]) -> Check[ContextT]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@ -1910,10 +1910,10 @@ def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
# if we're here, all checks failed
raise CheckAnyFailure(unwrapped, errors)
return check(predicate)
return check(predicate) # type: ignore
def has_role(item: Union[int, str], /) -> Callable[[T], T]:
def has_role(item: Union[int, str], /) -> Check[Any]:
"""A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified.
@ -2066,7 +2066,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
return check(predicate)
def has_permissions(**perms: bool) -> Callable[[T], T]:
def has_permissions(**perms: bool) -> Check[Any]:
"""A :func:`.check` that is added that checks if the member has all of
the permissions necessary.
@ -2114,7 +2114,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
def bot_has_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has
the permissions listed.
@ -2141,7 +2141,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
def has_guild_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_permissions`, but operates on guild wide
permissions instead of the current channel permissions.
@ -2170,7 +2170,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
def bot_has_guild_permissions(**perms: bool) -> Check[Any]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot
members guild permissions.
@ -2196,7 +2196,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
return check(predicate)
def dm_only() -> Callable[[T], T]:
def dm_only() -> Check[Any]:
"""A :func:`.check` that indicates this command must only be used in a
DM context. Only private messages are allowed when
using the command.
@ -2215,7 +2215,7 @@ def dm_only() -> Callable[[T], T]:
return check(predicate)
def guild_only() -> Callable[[T], T]:
def guild_only() -> Check[Any]:
"""A :func:`.check` that indicates this command must only be used in a
guild context only. Basically, no private messages are allowed when
using the command.
@ -2232,7 +2232,7 @@ def guild_only() -> Callable[[T], T]:
return check(predicate)
def is_owner() -> Callable[[T], T]:
def is_owner() -> Check[Any]:
"""A :func:`.check` that checks if the person invoking this command is the
owner of the bot.
@ -2250,7 +2250,7 @@ def is_owner() -> Callable[[T], T]:
return check(predicate)
def is_nsfw() -> Callable[[T], T]:
def is_nsfw() -> Check[Any]:
"""A :func:`.check` that checks if the channel is a NSFW channel.
This check raises a special exception, :exc:`.NSFWChannelRequired`

6
discord/ext/commands/help.py

@ -60,7 +60,7 @@ if TYPE_CHECKING:
from .parameters import Parameter
from ._types import (
Check,
UserCheck,
ContextT,
BotT,
_Bot,
@ -378,7 +378,7 @@ class HelpCommand:
bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog()
def add_check(self, func: Check[ContextT], /) -> None:
def add_check(self, func: UserCheck[ContextT], /) -> None:
"""
Adds a check to the help command.
@ -396,7 +396,7 @@ class HelpCommand:
self._command_impl.add_check(func)
def remove_check(self, func: Check[ContextT], /) -> None:
def remove_check(self, func: UserCheck[ContextT], /) -> None:
"""
Removes a check from the help command.

Loading…
Cancel
Save