From bea6b815e2c4b6807ba2c07a90d788736df8f98b Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 19 Mar 2022 01:01:30 -0400 Subject: [PATCH] Add support for app command checks This does not include any built-in checks due to design considerations. --- discord/app_commands/commands.py | 215 ++++++++++++++++++++++++++++++- discord/app_commands/errors.py | 12 ++ docs/interactions/api.rst | 7 + 3 files changed, 232 insertions(+), 2 deletions(-) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index cc1b58759..e524379ae 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -49,11 +49,11 @@ import re from ..enums import AppCommandOptionType, AppCommandType from .models import Choice from .transformers import annotation_to_parameter, CommandParameter, NoneType -from .errors import AppCommandError, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered +from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered from ..message import Message from ..user import User from ..member import Member -from ..utils import resolve_annotation, MISSING, is_inside_class +from ..utils import resolve_annotation, MISSING, is_inside_class, maybe_coroutine, async_all if TYPE_CHECKING: from typing_extensions import ParamSpec, Concatenate @@ -74,6 +74,7 @@ __all__ = ( 'context_menu', 'command', 'describe', + 'check', 'choices', 'autocomplete', 'guilds', @@ -91,6 +92,7 @@ Error = Union[ Callable[[GroupT, 'Interaction', AppCommandError], Coro[Any]], Callable[['Interaction', AppCommandError], Coro[Any]], ] +Check = Callable[['Interaction'], Union[bool, Coro[bool]]] if TYPE_CHECKING: @@ -121,6 +123,7 @@ else: AutocompleteCallback = Callable[..., Coro[T]] +CheckInputParameter = Union['Command[Any, ..., Any]', 'ContextMenu', CommandCallback, ContextMenuCallback] VALID_SLASH_COMMAND_NAME = re.compile(r'^[\w-]{1,32}$') VALID_CONTEXT_MENU_NAME = re.compile(r'^[\w\s-]{1,32}$') CAMEL_CASE_REGEX = re.compile(r'(? T: + if not await self._check_can_run(interaction): + raise CheckFailure(f'The check functions for command {self.name!r} failed.') + values = namespace.__dict__ for name, param in self._params.items(): try: @@ -515,6 +529,34 @@ class Command(Generic[GroupT, P, T]): parent = self.parent return parent.parent or parent + async def _check_can_run(self, interaction: Interaction) -> bool: + if self.parent is not None and self.parent is not self.binding: + # For commands with a parent which isn't the binding, i.e. + # + # + # + # The parent check needs to be called first + if not await maybe_coroutine(self.parent.interaction_check, interaction): + return False + + if self.binding is not None: + try: + # Type checker does not like runtime attribute retrieval + check: Check = self.binding.interaction_check # type: ignore + except AttributeError: + pass + else: + ret = await maybe_coroutine(check, interaction) + if not ret: + return False + + predicates = self.checks + if not predicates: + return True + + # Type checker does not understand negative narrowing cases like this function + return await async_all(f(interaction) for f in predicates) # type: ignore + def error(self, coro: Error[GroupT]) -> Error[GroupT]: """A decorator that registers a coroutine as a local error handler. @@ -611,6 +653,36 @@ class Command(Generic[GroupT, P, T]): return decorator + def add_check(self, func: Check, /) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`check`. + + Parameters + ----------- + func + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: Check, /) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + Parameters + ----------- + func + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass + class ContextMenu: """A class that implements a context menu application command. @@ -629,6 +701,12 @@ class ContextMenu: The name of the context menu. type: :class:`.AppCommandType` The type of context menu application command. + checks + A list of predicates that take a :class:`~discord.Interaction` parameter + to indicate whether the command callback should be executed. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`AppCommandError` should be used. If all the checks fail without + propagating an exception, :exc:`CheckFailure` is raised. """ def __init__( @@ -649,6 +727,7 @@ class ContextMenu: self._annotation = annotation self.module: Optional[str] = callback.__module__ self._guild_ids = guild_ids + self.checks: List[Check] = getattr(callback, '__discord_app_commands_checks__', []) @property def callback(self) -> ContextMenuCallback: @@ -667,6 +746,7 @@ class ContextMenu: self._annotation = annotation self.module = callback.__module__ self._guild_ids = None + self.checks = getattr(callback, '__discord_app_commands_checks__', []) return self def to_dict(self) -> Dict[str, Any]: @@ -675,14 +755,55 @@ class ContextMenu: 'type': self.type.value, } + async def _check_can_run(self, interaction: Interaction) -> bool: + predicates = self.checks + if not predicates: + return True + + # Type checker does not understand negative narrowing cases like this function + return await async_all(f(interaction) for f in predicates) # type: ignore + async def _invoke(self, interaction: Interaction, arg: Any): try: + if not await self._check_can_run(interaction): + raise CheckFailure(f'The check functions for context menu {self.name!r} failed.') + await self._callback(interaction, arg) except AppCommandError: raise except Exception as e: raise CommandInvokeError(self, e) from e + def add_check(self, func: Check, /) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`check`. + + Parameters + ----------- + func + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: Check, /) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + Parameters + ----------- + func + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass + class Group: """A class that implements an application command group. @@ -857,6 +978,37 @@ class Group: pass + async def interaction_check(self, interaction: Interaction) -> bool: + """|coro| + + A callback that is called when an interaction happens within the group + that checks whether a command inside the group should be executed. + + This is useful to override if, for example, you want to ensure that the + interaction author is a given user. + + The default implementation of this returns ``True``. + + .. note:: + + If an exception occurs within the body then the check + is considered a failure and error handlers such as + :meth:`on_error` is called. See :exc:`AppCommandError` + for more information. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that occurred. + + Returns + --------- + :class:`bool` + Whether the view children's callbacks should be called. + """ + + return True + def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None: """Adds a command or group to this group's internal list of commands. @@ -1260,3 +1412,62 @@ def guilds(*guild_ids: Union[Snowflake, int]) -> Callable[[T], T]: return inner return decorator + + +def check(predicate: Check) -> Callable[[T], T]: + r"""A decorator that adds a check to an application command. + + These checks should be predicates that take in a single parameter taking + a :class:`~discord.Interaction`. If the check returns a ``False``\-like value then + during invocation a :exc:`CheckFailure` exception is raised and sent to + the appropriate error handlers. + + These checks can be either a coroutine or not. + + Examples + --------- + + Creating a basic check to see if the command invoker is you. + + .. code-block:: python3 + + def check_if_it_is_me(interaction: discord.Interaction) -> bool: + return interaction.user.id == 85309593344815104 + + @tree.command() + @app_commands.check(check_if_it_is_me) + async def only_for_me(interaction: discord.Interaction): + await interaction.response.send_message('I know you!', ephemeral=True) + + Transforming common checks into its own decorator: + + .. code-block:: python3 + + def is_me(): + def predicate(interaction: discord.Interaction) -> bool: + return interaction.user.id == 85309593344815104 + return commands.check(predicate) + + @tree.command() + @is_me() + async def only_me(interaction: discord.Interaction): + await interaction.response.send_message('Only you!') + + Parameters + ----------- + predicate: Callable[[:class:`~discord.Interaction`], :class:`bool`] + The predicate to check if the command should be invoked. + """ + + def decorator(func: CheckInputParameter) -> CheckInputParameter: + if isinstance(func, (Command, ContextMenu)): + func.checks.append(predicate) + else: + if not hasattr(func, '__discord_app_commands_checks__'): + func.__discord_app_commands_checks__ = [] + + func.__discord_app_commands_checks__.append(predicate) + + return func + + return decorator # type: ignore diff --git a/discord/app_commands/errors.py b/discord/app_commands/errors.py index 73cda13ab..4bcbcc98a 100644 --- a/discord/app_commands/errors.py +++ b/discord/app_commands/errors.py @@ -34,6 +34,7 @@ __all__ = ( 'AppCommandError', 'CommandInvokeError', 'TransformerError', + 'CheckFailure', 'CommandAlreadyRegistered', 'CommandSignatureMismatch', 'CommandNotFound', @@ -128,6 +129,17 @@ class TransformerError(AppCommandError): super().__init__(f'Failed to convert {value} to {result_type!s}') +class CheckFailure(AppCommandError): + """An exception raised when check predicates in a command have failed. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + .. versionadded:: 2.0 + """ + + pass + + class CommandAlreadyRegistered(AppCommandError): """An exception raised when a command is already registered. diff --git a/docs/interactions/api.rst b/docs/interactions/api.rst index 63d934e24..289245061 100644 --- a/docs/interactions/api.rst +++ b/docs/interactions/api.rst @@ -465,6 +465,9 @@ Decorators .. autofunction:: discord.app_commands.choices :decorator: +.. autofunction:: discord.app_commands.check + :decorator: + .. autofunction:: discord.app_commands.autocomplete :decorator: @@ -518,6 +521,9 @@ Exceptions .. autoexception:: discord.app_commands.TransformerError :members: +.. autoexception:: discord.app_commands.CheckFailure + :members: + .. autoexception:: discord.app_commands.CommandAlreadyRegistered :members: @@ -536,6 +542,7 @@ Exception Hierarchy - :exc:`~discord.app_commands.AppCommandError` - :exc:`~discord.app_commands.CommandInvokeError` - :exc:`~discord.app_commands.TransformerError` + - :exc:`~discord.app_commands.CheckFailure` - :exc:`~discord.app_commands.CommandAlreadyRegistered` - :exc:`~discord.app_commands.CommandSignatureMismatch` - :exc:`~discord.app_commands.CommandNotFound`