Browse Source

Add support for app command checks

This does not include any built-in checks due to design considerations.
pull/7723/head
Rapptz 3 years ago
parent
commit
bea6b815e2
  1. 215
      discord/app_commands/commands.py
  2. 12
      discord/app_commands/errors.py
  3. 7
      docs/interactions/api.rst

215
discord/app_commands/commands.py

@ -49,11 +49,11 @@ import re
from ..enums import AppCommandOptionType, AppCommandType from ..enums import AppCommandOptionType, AppCommandType
from .models import Choice from .models import Choice
from .transformers import annotation_to_parameter, CommandParameter, NoneType from .transformers import annotation_to_parameter, CommandParameter, NoneType
from .errors import AppCommandError, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
from ..message import Message from ..message import Message
from ..user import User from ..user import User
from ..member import Member from ..member import Member
from ..utils import resolve_annotation, MISSING, is_inside_class from ..utils import resolve_annotation, MISSING, is_inside_class, maybe_coroutine, async_all
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec, Concatenate from typing_extensions import ParamSpec, Concatenate
@ -74,6 +74,7 @@ __all__ = (
'context_menu', 'context_menu',
'command', 'command',
'describe', 'describe',
'check',
'choices', 'choices',
'autocomplete', 'autocomplete',
'guilds', 'guilds',
@ -91,6 +92,7 @@ Error = Union[
Callable[[GroupT, 'Interaction', AppCommandError], Coro[Any]], Callable[[GroupT, 'Interaction', AppCommandError], Coro[Any]],
Callable[['Interaction', AppCommandError], Coro[Any]], Callable[['Interaction', AppCommandError], Coro[Any]],
] ]
Check = Callable[['Interaction'], Union[bool, Coro[bool]]]
if TYPE_CHECKING: if TYPE_CHECKING:
@ -121,6 +123,7 @@ else:
AutocompleteCallback = Callable[..., Coro[T]] AutocompleteCallback = Callable[..., Coro[T]]
CheckInputParameter = Union['Command[Any, ..., Any]', 'ContextMenu', CommandCallback, ContextMenuCallback]
VALID_SLASH_COMMAND_NAME = re.compile(r'^[\w-]{1,32}$') VALID_SLASH_COMMAND_NAME = re.compile(r'^[\w-]{1,32}$')
VALID_CONTEXT_MENU_NAME = re.compile(r'^[\w\s-]{1,32}$') VALID_CONTEXT_MENU_NAME = re.compile(r'^[\w\s-]{1,32}$')
CAMEL_CASE_REGEX = re.compile(r'(?<!^)(?=[A-Z])') CAMEL_CASE_REGEX = re.compile(r'(?<!^)(?=[A-Z])')
@ -356,6 +359,12 @@ class Command(Generic[GroupT, P, T]):
description: :class:`str` description: :class:`str`
The description of the application command. This shows up in the UI to describe The description of the application command. This shows up in the UI to describe
the application command. the application command.
checks
A list of predicates that take a :class:`~discord.Interaction` parameter
to indicate whether the command callback should be executed. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`AppCommandError` should be used. If all the checks fail without
propagating an exception, :exc:`CheckFailure` is raised.
parent: Optional[:class:`Group`] parent: Optional[:class:`Group`]
The parent application command. ``None`` if there isn't one. The parent application command. ``None`` if there isn't one.
""" """
@ -386,6 +395,7 @@ class Command(Generic[GroupT, P, T]):
pass pass
self._params: Dict[str, CommandParameter] = _extract_parameters_from_callback(callback, callback.__globals__) self._params: Dict[str, CommandParameter] = _extract_parameters_from_callback(callback, callback.__globals__)
self.checks: List[Check] = getattr(callback, '__discord_app_commands_checks__', [])
self._guild_ids: Optional[List[int]] = guild_ids or getattr( self._guild_ids: Optional[List[int]] = guild_ids or getattr(
callback, '__discord_app_commands_default_guilds__', None callback, '__discord_app_commands_default_guilds__', None
) )
@ -406,6 +416,7 @@ class Command(Generic[GroupT, P, T]):
copy = cls.__new__(cls) copy = cls.__new__(cls)
copy.name = self.name copy.name = self.name
copy._guild_ids = self._guild_ids copy._guild_ids = self._guild_ids
copy.checks = self.checks
copy.description = self.description copy.description = self.description
copy._attr = self._attr copy._attr = self._attr
copy._callback = self._callback copy._callback = self._callback
@ -443,6 +454,9 @@ class Command(Generic[GroupT, P, T]):
await parent.parent.on_error(interaction, self, error) await parent.parent.on_error(interaction, self, error)
async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T: async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T:
if not await self._check_can_run(interaction):
raise CheckFailure(f'The check functions for command {self.name!r} failed.')
values = namespace.__dict__ values = namespace.__dict__
for name, param in self._params.items(): for name, param in self._params.items():
try: try:
@ -515,6 +529,34 @@ class Command(Generic[GroupT, P, T]):
parent = self.parent parent = self.parent
return parent.parent or parent return parent.parent or parent
async def _check_can_run(self, interaction: Interaction) -> bool:
if self.parent is not None and self.parent is not self.binding:
# For commands with a parent which isn't the binding, i.e.
# <binding>
# <parent>
# <command>
# The parent check needs to be called first
if not await maybe_coroutine(self.parent.interaction_check, interaction):
return False
if self.binding is not None:
try:
# Type checker does not like runtime attribute retrieval
check: Check = self.binding.interaction_check # type: ignore
except AttributeError:
pass
else:
ret = await maybe_coroutine(check, interaction)
if not ret:
return False
predicates = self.checks
if not predicates:
return True
# Type checker does not understand negative narrowing cases like this function
return await async_all(f(interaction) for f in predicates) # type: ignore
def error(self, coro: Error[GroupT]) -> Error[GroupT]: def error(self, coro: Error[GroupT]) -> Error[GroupT]:
"""A decorator that registers a coroutine as a local error handler. """A decorator that registers a coroutine as a local error handler.
@ -611,6 +653,36 @@ class Command(Generic[GroupT, P, T]):
return decorator return decorator
def add_check(self, func: Check, /) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`check`.
Parameters
-----------
func
The function that will be used as a check.
"""
self.checks.append(func)
def remove_check(self, func: Check, /) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
if the function is not in the command's checks.
Parameters
-----------
func
The function to remove from the checks.
"""
try:
self.checks.remove(func)
except ValueError:
pass
class ContextMenu: class ContextMenu:
"""A class that implements a context menu application command. """A class that implements a context menu application command.
@ -629,6 +701,12 @@ class ContextMenu:
The name of the context menu. The name of the context menu.
type: :class:`.AppCommandType` type: :class:`.AppCommandType`
The type of context menu application command. The type of context menu application command.
checks
A list of predicates that take a :class:`~discord.Interaction` parameter
to indicate whether the command callback should be executed. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`AppCommandError` should be used. If all the checks fail without
propagating an exception, :exc:`CheckFailure` is raised.
""" """
def __init__( def __init__(
@ -649,6 +727,7 @@ class ContextMenu:
self._annotation = annotation self._annotation = annotation
self.module: Optional[str] = callback.__module__ self.module: Optional[str] = callback.__module__
self._guild_ids = guild_ids self._guild_ids = guild_ids
self.checks: List[Check] = getattr(callback, '__discord_app_commands_checks__', [])
@property @property
def callback(self) -> ContextMenuCallback: def callback(self) -> ContextMenuCallback:
@ -667,6 +746,7 @@ class ContextMenu:
self._annotation = annotation self._annotation = annotation
self.module = callback.__module__ self.module = callback.__module__
self._guild_ids = None self._guild_ids = None
self.checks = getattr(callback, '__discord_app_commands_checks__', [])
return self return self
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
@ -675,14 +755,55 @@ class ContextMenu:
'type': self.type.value, 'type': self.type.value,
} }
async def _check_can_run(self, interaction: Interaction) -> bool:
predicates = self.checks
if not predicates:
return True
# Type checker does not understand negative narrowing cases like this function
return await async_all(f(interaction) for f in predicates) # type: ignore
async def _invoke(self, interaction: Interaction, arg: Any): async def _invoke(self, interaction: Interaction, arg: Any):
try: try:
if not await self._check_can_run(interaction):
raise CheckFailure(f'The check functions for context menu {self.name!r} failed.')
await self._callback(interaction, arg) await self._callback(interaction, arg)
except AppCommandError: except AppCommandError:
raise raise
except Exception as e: except Exception as e:
raise CommandInvokeError(self, e) from e raise CommandInvokeError(self, e) from e
def add_check(self, func: Check, /) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`check`.
Parameters
-----------
func
The function that will be used as a check.
"""
self.checks.append(func)
def remove_check(self, func: Check, /) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
if the function is not in the command's checks.
Parameters
-----------
func
The function to remove from the checks.
"""
try:
self.checks.remove(func)
except ValueError:
pass
class Group: class Group:
"""A class that implements an application command group. """A class that implements an application command group.
@ -857,6 +978,37 @@ class Group:
pass pass
async def interaction_check(self, interaction: Interaction) -> bool:
"""|coro|
A callback that is called when an interaction happens within the group
that checks whether a command inside the group should be executed.
This is useful to override if, for example, you want to ensure that the
interaction author is a given user.
The default implementation of this returns ``True``.
.. note::
If an exception occurs within the body then the check
is considered a failure and error handlers such as
:meth:`on_error` is called. See :exc:`AppCommandError`
for more information.
Parameters
-----------
interaction: :class:`~discord.Interaction`
The interaction that occurred.
Returns
---------
:class:`bool`
Whether the view children's callbacks should be called.
"""
return True
def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None: def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None:
"""Adds a command or group to this group's internal list of commands. """Adds a command or group to this group's internal list of commands.
@ -1260,3 +1412,62 @@ def guilds(*guild_ids: Union[Snowflake, int]) -> Callable[[T], T]:
return inner return inner
return decorator return decorator
def check(predicate: Check) -> Callable[[T], T]:
r"""A decorator that adds a check to an application command.
These checks should be predicates that take in a single parameter taking
a :class:`~discord.Interaction`. If the check returns a ``False``\-like value then
during invocation a :exc:`CheckFailure` exception is raised and sent to
the appropriate error handlers.
These checks can be either a coroutine or not.
Examples
---------
Creating a basic check to see if the command invoker is you.
.. code-block:: python3
def check_if_it_is_me(interaction: discord.Interaction) -> bool:
return interaction.user.id == 85309593344815104
@tree.command()
@app_commands.check(check_if_it_is_me)
async def only_for_me(interaction: discord.Interaction):
await interaction.response.send_message('I know you!', ephemeral=True)
Transforming common checks into its own decorator:
.. code-block:: python3
def is_me():
def predicate(interaction: discord.Interaction) -> bool:
return interaction.user.id == 85309593344815104
return commands.check(predicate)
@tree.command()
@is_me()
async def only_me(interaction: discord.Interaction):
await interaction.response.send_message('Only you!')
Parameters
-----------
predicate: Callable[[:class:`~discord.Interaction`], :class:`bool`]
The predicate to check if the command should be invoked.
"""
def decorator(func: CheckInputParameter) -> CheckInputParameter:
if isinstance(func, (Command, ContextMenu)):
func.checks.append(predicate)
else:
if not hasattr(func, '__discord_app_commands_checks__'):
func.__discord_app_commands_checks__ = []
func.__discord_app_commands_checks__.append(predicate)
return func
return decorator # type: ignore

12
discord/app_commands/errors.py

@ -34,6 +34,7 @@ __all__ = (
'AppCommandError', 'AppCommandError',
'CommandInvokeError', 'CommandInvokeError',
'TransformerError', 'TransformerError',
'CheckFailure',
'CommandAlreadyRegistered', 'CommandAlreadyRegistered',
'CommandSignatureMismatch', 'CommandSignatureMismatch',
'CommandNotFound', 'CommandNotFound',
@ -128,6 +129,17 @@ class TransformerError(AppCommandError):
super().__init__(f'Failed to convert {value} to {result_type!s}') super().__init__(f'Failed to convert {value} to {result_type!s}')
class CheckFailure(AppCommandError):
"""An exception raised when check predicates in a command have failed.
This inherits from :exc:`~discord.app_commands.AppCommandError`.
.. versionadded:: 2.0
"""
pass
class CommandAlreadyRegistered(AppCommandError): class CommandAlreadyRegistered(AppCommandError):
"""An exception raised when a command is already registered. """An exception raised when a command is already registered.

7
docs/interactions/api.rst

@ -465,6 +465,9 @@ Decorators
.. autofunction:: discord.app_commands.choices .. autofunction:: discord.app_commands.choices
:decorator: :decorator:
.. autofunction:: discord.app_commands.check
:decorator:
.. autofunction:: discord.app_commands.autocomplete .. autofunction:: discord.app_commands.autocomplete
:decorator: :decorator:
@ -518,6 +521,9 @@ Exceptions
.. autoexception:: discord.app_commands.TransformerError .. autoexception:: discord.app_commands.TransformerError
:members: :members:
.. autoexception:: discord.app_commands.CheckFailure
:members:
.. autoexception:: discord.app_commands.CommandAlreadyRegistered .. autoexception:: discord.app_commands.CommandAlreadyRegistered
:members: :members:
@ -536,6 +542,7 @@ Exception Hierarchy
- :exc:`~discord.app_commands.AppCommandError` - :exc:`~discord.app_commands.AppCommandError`
- :exc:`~discord.app_commands.CommandInvokeError` - :exc:`~discord.app_commands.CommandInvokeError`
- :exc:`~discord.app_commands.TransformerError` - :exc:`~discord.app_commands.TransformerError`
- :exc:`~discord.app_commands.CheckFailure`
- :exc:`~discord.app_commands.CommandAlreadyRegistered` - :exc:`~discord.app_commands.CommandAlreadyRegistered`
- :exc:`~discord.app_commands.CommandSignatureMismatch` - :exc:`~discord.app_commands.CommandSignatureMismatch`
- :exc:`~discord.app_commands.CommandNotFound` - :exc:`~discord.app_commands.CommandNotFound`

Loading…
Cancel
Save