diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index ac902bb8b..c326f979f 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -56,6 +56,7 @@ from ..utils import resolve_annotation, MISSING, is_inside_class from ..user import User from ..member import Member from ..role import Role +from ..message import Message from ..mixins import Hashable from ..permissions import Permissions @@ -74,6 +75,7 @@ if TYPE_CHECKING: __all__ = ( 'CommandParameter', 'Command', + 'ContextMenu', 'Group', 'command', 'describe', @@ -88,6 +90,18 @@ T = TypeVar('T') GroupT = TypeVar('GroupT', bound='Group') Coro = Coroutine[Any, Any, T] +ContextMenuCallback = Union[ + # If groups end up support context menus these would be uncommented + # Callable[[GroupT, Interaction, Member], Coro[Any]], + # Callable[[GroupT, Interaction, User], Coro[Any]], + # Callable[[GroupT, Interaction, Message], Coro[Any]], + # Callable[[GroupT, Interaction, Union[Member, User]], Coro[Any]], + Callable[[Interaction, Member], Coro[Any]], + Callable[[Interaction, User], Coro[Any]], + Callable[[Interaction, Message], Coro[Any]], + Callable[[Interaction, Union[Member, User]], Coro[Any]], +] + if TYPE_CHECKING: CommandCallback = Union[ Callable[Concatenate[GroupT, Interaction, P], Coro[T]], @@ -149,8 +163,7 @@ class CommandParameter: min_value: Optional[int] = None max_value: Optional[int] = None autocomplete: bool = MISSING - annotation: Any = MISSING - # restrictor: Optional[RestrictorType] = None + _annotation: Any = MISSING def to_dict(self) -> Dict[str, Any]: base = { @@ -231,7 +244,7 @@ def _annotation_to_type( # Check if there's an origin origin = getattr(annotation, '__origin__', None) - if origin is not Union: # TODO: Python 3.10 + if origin is not Union: # Only Union/Optional is supported so bail early raise TypeError(f'unsupported type annotation {annotation!r}') @@ -264,6 +277,31 @@ def _annotation_to_type( return (AppCommandOptionType.mentionable, default) +def _context_menu_annotation(annotation: Any, *, _none=NoneType) -> AppCommandType: + if annotation is Message: + return AppCommandType.message + + supported_types: Set[Any] = {Member, User} + if annotation in supported_types: + return AppCommandType.user + + # Check if there's an origin + origin = getattr(annotation, '__origin__', None) + if origin is not Union: + # Only Union is supported so bail early + msg = ( + f'unsupported type annotation {annotation!r}, must be either discord.Member, ' + 'discord.User, discord.Message, or a typing.Union of discord.Member and discord.User' + ) + raise TypeError(msg) + + # Only Union[Member, User] is supported + if not all(arg in supported_types for arg in annotation.__args__): + raise TypeError(f'unsupported types given inside {annotation!r}') + + return AppCommandType.user + + def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Dict[str, Any]) -> None: for name, param in params.items(): description = descriptions.pop(name, MISSING) @@ -304,7 +342,7 @@ def _get_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandPara if not isinstance(result.default, valid_types): raise TypeError(f'invalid default parameter type given ({result.default.__class__}), expected {valid_types}') - result.annotation = annotation + result._annotation = annotation return result @@ -341,6 +379,31 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s return result +def _get_context_menu_parameter(func: ContextMenuCallback) -> Tuple[str, Any, AppCommandType]: + params = inspect.signature(func).parameters + if len(params) != 2: + msg = ( + 'context menu callbacks require 2 parameters, the first one being the annotation and the ' + 'other one explicitly annotated with either discord.Message, discord.User, discord.Member, ' + 'or a typing.Union of discord.Member and discord.User' + ) + raise TypeError(msg) + + iterator = iter(params.values()) + next(iterator) # skip interaction + parameter = next(iterator) + if parameter.annotation is parameter.empty: + msg = ( + 'second parameter of context menu callback must be explicitly annotated with either discord.Message, ' + 'discord.User, discord.Member, or a typing.Union of discord.Member and discord.User' + ) + raise TypeError(msg) + + resolved = resolve_annotation(parameter.annotation, func.__globals__, func.__globals__, {}) + type = _context_menu_annotation(resolved) + return (parameter.name, resolved, type) + + class Command(Generic[GroupT, P, T]): """A class that implements an application command. @@ -349,6 +412,7 @@ class Command(Generic[GroupT, P, T]): - :func:`~discord.app_commands.command` - :meth:`Group.command ` + - :meth:`CommandTree.command ` .. versionadded:: 2.0 @@ -356,8 +420,6 @@ class Command(Generic[GroupT, P, T]): ------------ name: :class:`str` The name of the application command. - type: :class:`AppCommandType` - The type of application command. callback: :ref:`coroutine ` The coroutine that is executed when the command is called. description: :class:`str` @@ -373,7 +435,6 @@ class Command(Generic[GroupT, P, T]): name: str, description: str, callback: CommandCallback[GroupT, P, T], - type: AppCommandType = AppCommandType.chat_input, parent: Optional[Group] = None, ): self.name: str = name @@ -381,7 +442,6 @@ class Command(Generic[GroupT, P, T]): self._callback: CommandCallback[GroupT, P, T] = callback self.parent: Optional[Group] = parent self.binding: Optional[GroupT] = None - self.type: AppCommandType = type self._params: Dict[str, CommandParameter] = _extract_parameters_from_callback(callback, callback.__globals__) def _copy_with_binding(self, binding: GroupT) -> Command: @@ -391,7 +451,6 @@ class Command(Generic[GroupT, P, T]): copy.description = self.description copy._callback = self._callback copy.parent = self.parent - copy.type = self.type copy._params = self._params.copy() copy.binding = binding return copy @@ -399,7 +458,7 @@ class Command(Generic[GroupT, P, T]): def to_dict(self) -> Dict[str, Any]: # If we have a parent then our type is a subcommand # Otherwise, the type falls back to the specific command type (e.g. slash command or context menu) - option_type = self.type.value if self.parent is None else AppCommandOptionType.subcommand.value + option_type = AppCommandType.chat_input.value if self.parent is None else AppCommandOptionType.subcommand.value return { 'name': self.name, 'description': self.description, @@ -431,20 +490,8 @@ class Command(Generic[GroupT, P, T]): raise CommandSignatureMismatch(self) from None raise - def get_parameter(self, name: str) -> Optional[CommandParameter]: - """Returns the :class:`CommandParameter` with the given name. - - Parameters - ----------- - name: :class:`str` - The parameter name to get. - - Returns - -------- - Optional[:class:`CommandParameter`] - The command parameter, if found. - """ - return self._params.get(name) + def _get_internal_command(self, name: str) -> Optional[Union[Command, Group]]: + return None @property def root_parent(self) -> Optional[Group]: @@ -454,8 +501,64 @@ class Command(Generic[GroupT, P, T]): parent = self.parent return parent.parent or parent - def _get_internal_command(self, name: str) -> Optional[Union[Command, Group]]: - return None + +class ContextMenu: + """A class that implements a context menu application command. + + These are usually not created manually, instead they are created using + one of the following decorators: + + - :func:`~discord.app_commands.context_menu` + - :meth:`CommandTree.command ` + + .. versionadded:: 2.0 + + Attributes + ------------ + name: :class:`str` + The name of the context menu. + callback: :ref:`coroutine ` + The coroutine that is executed when the context menu is called. + type: :class:`.AppCommandType` + The type of context menu application command. + """ + + def __init__( + self, + *, + name: str, + callback: ContextMenuCallback, + type: AppCommandType, + ): + self.name: str = name + self._callback: ContextMenuCallback = callback + self.type: AppCommandType = type + (param, annotation, actual_type) = _get_context_menu_parameter(callback) + if actual_type != type: + raise ValueError(f'context menu callback implies a type of {actual_type} but {type} was passed.') + self._param_name = param + self._annotation = annotation + + @classmethod + def _from_decorator(cls, callback: ContextMenuCallback, *, name: str = MISSING) -> ContextMenu: + (param, annotation, type) = _get_context_menu_parameter(callback) + + self = cls.__new__(cls) + self.name = callback.__name__.title() if name is MISSING else name + self._callback = callback + self.type = type + self._param_name = param + self._annotation = annotation + return self + + def to_dict(self) -> Dict[str, Any]: + return { + 'name': self.name, + 'type': self.type.value, + } + + async def _invoke(self, interaction: Interaction, arg: Any): + await self._callback(interaction, arg) class Group: @@ -581,8 +684,13 @@ class Group: attribute will always be ``None`` in this case. ValueError There are too many commands already registered. + TypeError + The wrong command type was passed. """ + if not isinstance(command, (Command, Group)): + raise TypeError(f'expected Command or Group not {command.__class__!r}') + if not override and command.name in self._children: raise CommandAlreadyRegistered(command.name, guild_id=None) @@ -658,7 +766,6 @@ class Group: name=name if name is not MISSING else func.__name__, description=desc, callback=func, - type=AppCommandType.chat_input, parent=self, ) self.add_command(command) @@ -701,13 +808,49 @@ def command( name=name if name is not MISSING else func.__name__, description=desc, callback=func, - type=AppCommandType.chat_input, parent=None, ) return decorator +def context_menu(*, name: str = MISSING) -> Callable[[ContextMenuCallback], ContextMenu]: + """Creates a application command context menu from a regular function. + + This function must have a signature of :class:`~discord.Interaction` as its first parameter + and taking either a :class:`~discord.Member`, :class:`~discord.User`, or :class:`~discord.Message`, + or a :obj:`typing.Union` of ``Member`` and ``User`` as its second parameter. + + Examples + --------- + + .. code-block:: python3 + + @app_commands.context_menu() + async def react(interaction: discord.Interaction, message: discord.Message): + await interaction.response.send_message('Very cool message!', ephemeral=True) + + @app_commands.context_menu() + async def ban(interaction: discord.Interaction, user: discord.Member): + await interaction.response.send_message(f'Should I actually ban {user}...', ephemeral=True) + + Parameters + ------------ + name: :class:`str` + The name of the context menu command. If not given, it defaults to a title-case + version of the callback name. Note that unlike regular slash commands this can + have spaces and upper case characters in the name. + """ + + def decorator(func: ContextMenuCallback) -> ContextMenu: + if not inspect.iscoroutinefunction(func): + raise TypeError('context menu function must be a coroutine function') + + return ContextMenu._from_decorator(func, name=name) + + return decorator + + def describe(**parameters: str) -> Callable[[T], T]: r"""Describes the given parameters by their name using the key of the keyword argument as the name. diff --git a/discord/app_commands/errors.py b/discord/app_commands/errors.py index 919004e27..086e5e7f6 100644 --- a/discord/app_commands/errors.py +++ b/discord/app_commands/errors.py @@ -25,6 +25,8 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import TYPE_CHECKING, List, Optional, Union + +from .enums import AppCommandType from ..errors import DiscordException __all__ = ( @@ -34,7 +36,7 @@ __all__ = ( ) if TYPE_CHECKING: - from .commands import Command, Group + from .commands import Command, Group, ContextMenu class CommandAlreadyRegistered(DiscordException): @@ -50,8 +52,8 @@ class CommandAlreadyRegistered(DiscordException): """ def __init__(self, name: str, guild_id: Optional[int]): - self.name = name - self.guild_id = guild_id + self.name: str = name + self.guild_id: Optional[int] = guild_id super().__init__(f'Command {name!r} already registered.') @@ -65,11 +67,14 @@ class CommandNotFound(DiscordException): parents: List[:class:`str`] A list of parent command names that were previously found prior to the application command not being found. + type: :class:`AppCommandType` + The type of command that was not found. """ - def __init__(self, name: str, parents: List[str]): - self.name = name - self.parents = parents + def __init__(self, name: str, parents: List[str], type: AppCommandType = AppCommandType.chat_input): + self.name: str = name + self.parents: List[str] = parents + self.type: AppCommandType = type super().__init__(f'Application command {name!r} not found') @@ -81,14 +86,14 @@ class CommandSignatureMismatch(DiscordException): Attributes ------------ - command: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.Group`] + command: Union[:class:`~.app_commands.Command`, :class:`~.app_commands.ContextMenu`, :class:`~.app_commands.Group`] The command that had the signature mismatch. """ - def __init__(self, command: Union[Command, Group]): - self.command: Union[Command, Group] = command + def __init__(self, command: Union[Command, ContextMenu, Group]): + self.command: Union[Command, ContextMenu, Group] = command msg = ( - f'The signature for command {command!r} is different from the one provided by Discord. ' + f'The signature for command {command.name!r} is different from the one provided by Discord. ' 'This can happen because either your code is out of date or you have not synced the ' 'commands with Discord, causing the mismatch in data. It is recommended to sync the ' 'command tree to fix this issue.' diff --git a/discord/app_commands/namespace.py b/discord/app_commands/namespace.py index 25e85388e..a81a633b3 100644 --- a/discord/app_commands/namespace.py +++ b/discord/app_commands/namespace.py @@ -30,6 +30,7 @@ from ..member import Member from ..object import Object from ..role import Role from ..message import Message, Attachment +from ..channel import PartialMessageable from .models import AppCommandChannel, AppCommandThread if TYPE_CHECKING: @@ -86,6 +87,27 @@ class Namespace: resolved: ResolvedData, options: List[ApplicationCommandInteractionDataOption], ): + completed = self._get_resolved_items(interaction, resolved) + for option in options: + opt_type = option['type'] + name = option['name'] + if opt_type in (3, 4, 5): # string, integer, boolean + value = option['value'] # type: ignore -- Key is there + self.__dict__[name] = value + elif opt_type == 10: # number + value = option['value'] # type: ignore -- Key is there + if value is None: + self.__dict__[name] = float('nan') + else: + self.__dict__[name] = float(value) + elif opt_type in (6, 7, 8, 9, 11): + # Remaining ones should be snowflake based ones with resolved data + snowflake: str = option['value'] # type: ignore -- Key is there + value = completed.get(snowflake) + self.__dict__[name] = value + + @classmethod + def _get_resolved_items(cls, interaction: Interaction, resolved: ResolvedData) -> Dict[str, Any]: completed: Dict[str, Any] = {} state = interaction._state members = resolved.get('members', {}) @@ -126,25 +148,18 @@ class Namespace: } ) - # TODO: messages + guild = state._get_guild(guild_id) + for (message_id, message_data) in resolved.get('messages', {}).items(): + channel_id = int(message_data['channel_id']) + if guild is None: + channel = PartialMessageable(state=state, id=channel_id) + else: + channel = guild.get_channel_or_thread(channel_id) or PartialMessageable(state=state, id=channel_id) - for option in options: - opt_type = option['type'] - name = option['name'] - if opt_type in (3, 4, 5): # string, integer, boolean - value = option['value'] # type: ignore -- Key is there - self.__dict__[name] = value - elif opt_type == 10: # number - value = option['value'] # type: ignore -- Key is there - if value is None: - self.__dict__[name] = float('nan') - else: - self.__dict__[name] = float(value) - elif opt_type in (6, 7, 8, 9, 11): - # Remaining ones should be snowflake based ones with resolved data - snowflake: str = option['value'] # type: ignore -- Key is there - value = completed.get(snowflake) - self.__dict__[name] = value + # Type checker doesn't understand this due to failure to narrow + completed[message_id] = Message(state=state, channel=channel, data=message_data) # type: ignore + + return completed def __repr__(self) -> str: items = (f'{k}={v!r}' for k, v in self.__dict__.items()) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index 6d7e63398..1fdf0aa0a 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -24,12 +24,12 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import inspect -from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, Union, overload from .namespace import Namespace from .models import AppCommand -from .commands import Command, Group, _shorten +from .commands import Command, ContextMenu, Group, _shorten from .enums import AppCommandType from .errors import CommandAlreadyRegistered, CommandNotFound, CommandSignatureMismatch from ..errors import ClientException @@ -40,7 +40,7 @@ if TYPE_CHECKING: from ..interactions import Interaction from ..client import Client from ..abc import Snowflake - from .commands import CommandCallback, P, T + from .commands import ContextMenuCallback, CommandCallback, P, T __all__ = ('CommandTree',) @@ -65,7 +65,7 @@ class CommandTree: # The above two mappings can use this structure too but we need fast retrieval # by name and guild_id in the above case while here it isn't as important since # it's uncommon and N=5 anyway. - self._context_menus: Dict[Tuple[str, Optional[int], int], Command] = {} + self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: """|coro| @@ -75,6 +75,10 @@ class CommandTree: If no guild is passed then global commands are fetched, otherwise the guild's commands are fetched instead. + .. note:: + + This includes context menu commands. + Parameters ----------- guild: Optional[:class:`abc.Snowflake`] @@ -103,7 +107,14 @@ class CommandTree: return [AppCommand(data=data, state=self._state) for data in commands] - def add_command(self, command: Union[Command, Group], /, *, guild: Optional[Snowflake] = None, override: bool = False): + def add_command( + self, + command: Union[Command, ContextMenu, Group], + /, + *, + guild: Optional[Snowflake] = None, + override: bool = False, + ): """Adds an application command to the tree. This only adds the command locally -- in order to sync the commands @@ -133,7 +144,20 @@ class CommandTree: This is currently 100 for slash commands and 5 for context menu commands. """ - if not isinstance(command, (Command, Group)): + if isinstance(command, ContextMenu): + guild_id = None if guild is None else guild.id + type = command.type.value + key = (command.name, guild_id, type) + found = key in self._context_menus + if found and not override: + raise CommandAlreadyRegistered(command.name, guild_id) + + total = sum(1 for _, g, t in self._context_menus if g == guild_id and t == type) + if total + found > 5: + raise ValueError('maximum number of context menu commands exceeded (5)') + self._context_menus[key] = command + return + elif not isinstance(command, (Command, Group)): raise TypeError(f'Expected a application command, received {command.__class__!r} instead') # todo: validate application command groups having children (required) @@ -156,7 +180,36 @@ class CommandTree: raise ValueError('maximum number of slash commands exceeded (100)') self._global_commands[name] = root - def remove_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]: + @overload + def remove_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.message, AppCommandType.user] = ..., + ) -> Optional[ContextMenu]: + ... + + @overload + def remove_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.chat_input] = ..., + ) -> Optional[Union[Command, Group]]: + ... + + def remove_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = None, + type: AppCommandType = AppCommandType.chat_input, + ) -> Optional[Union[Command, ContextMenu, Group]]: """Removes an application command from the tree. This only removes the command locally -- in order to sync the commands @@ -169,30 +222,63 @@ class CommandTree: guild: Optional[:class:`abc.Snowflake`] The guild to remove the command from. If not given then it removes a global command instead. + type: :class:`AppCommandType` + The type of command to remove. Defaults to :attr:`AppCommandType.chat_input`, + i.e. slash commands. Returns --------- - Optional[Union[:class:`Command`, :class:`Group`]] + Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]] The application command that got removed. If nothing was removed then ``None`` is returned instead. """ - if guild is None: - return self._global_commands.pop(command, None) - else: - try: - commands = self._guild_commands[guild.id] - except KeyError: - return None + if type is AppCommandType.chat_input: + if guild is None: + return self._global_commands.pop(command, None) else: - return commands.pop(command, None) - - def get_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]: - """Gets a application command from the tree. + try: + commands = self._guild_commands[guild.id] + except KeyError: + return None + else: + return commands.pop(command, None) + elif type in (AppCommandType.user, AppCommandType.message): + guild_id = None if guild is None else guild.id + key = (command, guild_id, type.value) + return self._context_menus.pop(key, None) + + @overload + def get_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.message, AppCommandType.user] = ..., + ) -> Optional[ContextMenu]: + ... - .. note:: + @overload + def get_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.chat_input] = ..., + ) -> Optional[Union[Command, Group]]: + ... - This does *not* include context menu commands. + def get_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = None, + type: AppCommandType = AppCommandType.chat_input, + ) -> Optional[Union[Command, ContextMenu, Group]]: + """Gets a application command from the tree. Parameters ----------- @@ -201,52 +287,103 @@ class CommandTree: guild: Optional[:class:`abc.Snowflake`] The guild to get the command from. If not given then it gets a global command instead. + type: :class:`AppCommandType` + The type of command to get. Defaults to :attr:`AppCommandType.chat_input`, + i.e. slash commands. Returns --------- - Optional[Union[:class:`Command`, :class:`Group`]] + Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]] The application command that was found. If nothing was found then ``None`` is returned instead. """ - if guild is None: - return self._global_commands.get(command) - else: - try: - commands = self._guild_commands[guild.id] - except KeyError: - return None + if type is AppCommandType.chat_input: + if guild is None: + return self._global_commands.get(command) else: - return commands.get(command) - - def get_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group]]: - """Gets all application commands from the tree. + try: + commands = self._guild_commands[guild.id] + except KeyError: + return None + else: + return commands.get(command) + elif type in (AppCommandType.user, AppCommandType.message): + guild_id = None if guild is None else guild.id + key = (command, guild_id, type.value) + return self._context_menus.get(key) + + @overload + def get_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.message, AppCommandType.user] = ..., + ) -> List[ContextMenu]: + ... - .. note:: + @overload + def get_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.chat_input] = ..., + ) -> List[Union[Command, Group]]: + ... - This does *not* retrieve context menu commands. + def get_commands( + self, + *, + guild: Optional[Snowflake] = None, + type: AppCommandType = AppCommandType.chat_input, + ) -> Union[List[Union[Command, Group]], List[ContextMenu]]: + """Gets all application commands from the tree. Parameters ----------- guild: Optional[:class:`~discord.abc.Snowflake`] The guild to get the commands from. If not given then it gets all global commands instead. + type: :class:`AppCommandType` + The type of commands to get. Defaults to :attr:`AppCommandType.chat_input`, + i.e. slash commands. Returns --------- - List[Union[:class:`Command`, :class:`Group`]] + Union[List[:class:`ContextMenu`], List[Union[:class:`Command`, :class:`Group`]] The application commands from the tree. """ + if type is AppCommandType.chat_input: + if guild is None: + return list(self._global_commands.values()) + else: + try: + commands = self._guild_commands[guild.id] + except KeyError: + return [] + else: + return list(commands.values()) + else: + guild_id = None if guild is None else guild.id + value = type.value + return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value] + + def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]: if guild is None: - return list(self._global_commands.values()) + base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values()) + base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None) + return base else: try: commands = self._guild_commands[guild.id] except KeyError: - return [] + return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None] else: - return list(commands.values()) + base: List[Union[Command, Group, ContextMenu]] = list(commands.values()) + guild_id = guild.id + base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id) + return base def command( self, @@ -266,7 +403,7 @@ class CommandTree: The description of the application command. This shows up in the UI to describe the application command. If not given, it defaults to the first line of the docstring of the callback shortened to 100 characters. - guild: Optional[:class:`Snowflake`] + guild: Optional[:class:`.abc.Snowflake`] The guild to add the command to. If not given then it becomes a global command instead. """ @@ -287,7 +424,6 @@ class CommandTree: name=name if name is not MISSING else func.__name__, description=desc, callback=func, - type=AppCommandType.chat_input, parent=None, ) self.add_command(command, guild=guild) @@ -295,6 +431,49 @@ class CommandTree: return decorator + def context_menu( + self, *, name: str = MISSING, guild: Optional[Snowflake] = None + ) -> Callable[[ContextMenuCallback], ContextMenu]: + """Creates a application command context menu from a regular function directly under this tree. + + This function must have a signature of :class:`~discord.Interaction` as its first parameter + and taking either a :class:`~discord.Member`, :class:`~discord.User`, or :class:`~discord.Message`, + or a :obj:`typing.Union` of ``Member`` and ``User`` as its second parameter. + + Examples + --------- + + .. code-block:: python3 + + @app_commands.context_menu() + async def react(interaction: discord.Interaction, message: discord.Message): + await interaction.response.send_message('Very cool message!', ephemeral=True) + + @app_commands.context_menu() + async def ban(interaction: discord.Interaction, user: discord.Member): + await interaction.response.send_message(f'Should I actually ban {user}...', ephemeral=True) + + Parameters + ------------ + name: :class:`str` + The name of the context menu command. If not given, it defaults to a title-case + version of the callback name. Note that unlike regular slash commands this can + have spaces and upper case characters in the name. + guild: Optional[:class:`.abc.Snowflake`] + The guild to add the command to. If not given then it + becomes a global command instead. + """ + + def decorator(func: ContextMenuCallback) -> ContextMenu: + if not inspect.iscoroutinefunction(func): + raise TypeError('context menu function must be a coroutine function') + + context_menu = ContextMenu._from_decorator(func, name=name) + self.add_command(context_menu, guild=guild) + return context_menu + + return decorator + async def sync(self, *, guild: Optional[Snowflake]) -> List[AppCommand]: """|coro| @@ -327,7 +506,7 @@ class CommandTree: if self.client.application_id is None: raise ClientException('Client does not have an application ID set') - commands = self.get_commands(guild=guild) + commands = self._get_all_commands(guild=guild) payload = [command.to_dict() for command in commands] if guild is None: data = await self._http.bulk_upsert_global_commands(self.client.application_id, payload=payload) @@ -345,6 +524,25 @@ class CommandTree: self.client.loop.create_task(wrapper(), name='CommandTree-invoker') + async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int): + name = data['name'] + guild_id = interaction.guild_id + ctx_menu = self._context_menus.get((name, guild_id, type)) + if ctx_menu is None: + raise CommandNotFound(name, [], AppCommandType(type)) + + resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) + # This will always work at runtime + value = resolved.get(data.get('target_id')) # type: ignore + if ctx_menu.type.value != type: + raise CommandSignatureMismatch(ctx_menu) + + if value is None: + raise RuntimeError('This should not happen if Discord sent well-formed data.') + + # I assume I don't have to type check here. + await ctx_menu._invoke(interaction, value) + async def call(self, interaction: Interaction): """|coro| @@ -367,6 +565,12 @@ class CommandTree: application command definition. """ data: ApplicationCommandInteractionData = interaction.data # type: ignore + type = data.get('type', 1) + if type != 1: + # Context menu command... + await self._call_context_menu(interaction, data, type) + return + parents: List[str] = [] name = data['name'] command = self._global_commands.get(name)