diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index f943a751c..34400178d 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -31,6 +31,7 @@ from typing import ( Any, TYPE_CHECKING, Callable, + Coroutine, Dict, Generator, Generic, @@ -66,6 +67,15 @@ if TYPE_CHECKING: from ..abc import Snowflake from .commands import ContextMenuCallback, CommandCallback, P, T + ErrorFunc = Callable[ + [ + Interaction, + Optional[Union[ContextMenu, Command[Any, ..., Any]]], + AppCommandError, + ], + Coroutine[Any, Any, Any], + ] + __all__ = ('CommandTree',) ClientT = TypeVar('ClientT', bound='Client') @@ -681,6 +691,36 @@ class CommandTree(Generic[ClientT]): traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) + def error(self, coro: ErrorFunc) -> ErrorFunc: + """A decorator that registers a coroutine as a local error handler. + + This must match the signature of the :meth:`on_error` callback. + + The error passed will be derived from :exc:`AppCommandError`. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine or does + not match the signature. + """ + + if not inspect.iscoroutinefunction(coro): + raise TypeError('The error handler must be a coroutine.') + + params = inspect.signature(coro).parameters + if len(params) != 3: + raise TypeError('error handler must have 3 parameters') + + # Type checker doesn't like overriding methods like this + self.on_error = coro # type: ignore + return coro + def command( self, *,