From 0546343bcb138b71ca1f16a5ae10893efeac651c Mon Sep 17 00:00:00 2001 From: Stocker <44980366+StockerMC@users.noreply.github.com> Date: Sun, 17 Jul 2022 23:45:19 -0400 Subject: [PATCH] [commands] Add cog-level app command error special method --- discord/app_commands/commands.py | 6 ++++- discord/app_commands/tree.py | 13 +++++++--- discord/ext/commands/cog.py | 43 +++++++++++++++++++++++++++++++- 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 60c6fb26e..4ab57f091 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -598,7 +598,7 @@ class Command(Generic[GroupT, P, T]): return base - async def _invoke_error_handler(self, interaction: Interaction, error: AppCommandError) -> None: + async def _invoke_error_handlers(self, interaction: Interaction, error: AppCommandError) -> None: # These type ignores are because the type checker can't narrow this type properly. if self.on_error is not None: if self.binding is not None: @@ -613,6 +613,10 @@ class Command(Generic[GroupT, P, T]): if parent.parent is not None: await parent.parent.on_error(interaction, error) + cog_error = getattr(self.binding, '__app_commands_error_handler__', None) + if cog_error is not None: + await cog_error(interaction, error) + def _has_any_error_handlers(self) -> bool: if self.on_error is not None: return True diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index 0d66b50fd..5f20060cb 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -992,12 +992,19 @@ class CommandTree(Generic[ClientT]): return [AppCommand(data=d, state=self._state) for d in data] - def _from_interaction(self, interaction: Interaction): + async def _dispatch_error(self, interaction: Interaction, error: AppCommandError, /) -> None: + command = interaction.command + if isinstance(command, Command): + await command._invoke_error_handlers(interaction, error) + else: + await self.on_error(interaction, error) + + def _from_interaction(self, interaction: Interaction) -> None: async def wrapper(): try: await self.call(interaction) except AppCommandError as e: - await self.on_error(interaction, e) + await self._dispatch_error(interaction, e) self.client.loop.create_task(wrapper(), name='CommandTree-invoker') @@ -1167,5 +1174,5 @@ class CommandTree(Generic[ClientT]): try: await command._invoke_with_namespace(interaction, namespace) except AppCommandError as e: - await command._invoke_error_handler(interaction, e) + await command._invoke_error_handlers(interaction, e) await self.on_error(interaction, e) diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 6f14d7ddb..65b691be5 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -28,7 +28,21 @@ import discord from discord import app_commands from discord.utils import maybe_coroutine -from typing import Any, Callable, ClassVar, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + ClassVar, + Coroutine, + Dict, + Generator, + Iterable, + List, + Optional, + TYPE_CHECKING, + Tuple, + TypeVar, + Union, +) from ._types import _BaseCommand, BotT @@ -254,6 +268,9 @@ class Cog(metaclass=CogMeta): __cog_listeners__: List[Tuple[str, str]] __cog_is_app_commands_group__: ClassVar[bool] = False __cog_app_commands_group__: Optional[app_commands.Group] + __app_commands_error_handler__: Optional[ + Callable[[discord.Interaction, app_commands.AppCommandError], Coroutine[Any, Any, None]] + ] def __new__(cls, *args: Any, **kwargs: Any) -> Self: # For issue 426, we need to store a copy of the command objects @@ -329,6 +346,11 @@ class Cog(metaclass=CogMeta): self.__cog_app_commands_group__._children = mapping # type: ignore # Variance issue + if Cog._get_overridden_method(self.cog_app_command_error) is not None: + self.__app_commands_error_handler__ = self.cog_app_command_error + else: + self.__app_commands_error_handler__ = None + return self def get_commands(self) -> List[Command[Self, ..., Any]]: @@ -524,6 +546,25 @@ class Cog(metaclass=CogMeta): """ pass + @_cog_special_method + async def cog_app_command_error(self, interaction: discord.Interaction, error: app_commands.AppCommandError) -> None: + """A special method that is called whenever an error within + an application command is dispatched inside this cog. + + This is similar to :func:`discord.app_commands.CommandTree.on_error` except + only applying to the application commands inside this cog. + + This **must** be a coroutine. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that is being handled. + error: :exc:`~discord.app_commands.AppCommandError` + The exception that was raised. + """ + pass + @_cog_special_method async def cog_before_invoke(self, ctx: Context[BotT]) -> None: """A special method that acts as a cog local pre-invoke hook.