|
|
@ -24,12 +24,12 @@ DEALINGS IN THE SOFTWARE. |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import inspect |
|
|
|
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union |
|
|
|
from typing import Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, Union, overload |
|
|
|
|
|
|
|
|
|
|
|
from .namespace import Namespace |
|
|
|
from .models import AppCommand |
|
|
|
from .commands import Command, Group, _shorten |
|
|
|
from .commands import Command, ContextMenu, Group, _shorten |
|
|
|
from .enums import AppCommandType |
|
|
|
from .errors import CommandAlreadyRegistered, CommandNotFound, CommandSignatureMismatch |
|
|
|
from ..errors import ClientException |
|
|
@ -40,7 +40,7 @@ if TYPE_CHECKING: |
|
|
|
from ..interactions import Interaction |
|
|
|
from ..client import Client |
|
|
|
from ..abc import Snowflake |
|
|
|
from .commands import CommandCallback, P, T |
|
|
|
from .commands import ContextMenuCallback, CommandCallback, P, T |
|
|
|
|
|
|
|
__all__ = ('CommandTree',) |
|
|
|
|
|
|
@ -65,7 +65,7 @@ class CommandTree: |
|
|
|
# The above two mappings can use this structure too but we need fast retrieval |
|
|
|
# by name and guild_id in the above case while here it isn't as important since |
|
|
|
# it's uncommon and N=5 anyway. |
|
|
|
self._context_menus: Dict[Tuple[str, Optional[int], int], Command] = {} |
|
|
|
self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} |
|
|
|
|
|
|
|
async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: |
|
|
|
"""|coro| |
|
|
@ -75,6 +75,10 @@ class CommandTree: |
|
|
|
If no guild is passed then global commands are fetched, otherwise |
|
|
|
the guild's commands are fetched instead. |
|
|
|
|
|
|
|
.. note:: |
|
|
|
|
|
|
|
This includes context menu commands. |
|
|
|
|
|
|
|
Parameters |
|
|
|
----------- |
|
|
|
guild: Optional[:class:`abc.Snowflake`] |
|
|
@ -103,7 +107,14 @@ class CommandTree: |
|
|
|
|
|
|
|
return [AppCommand(data=data, state=self._state) for data in commands] |
|
|
|
|
|
|
|
def add_command(self, command: Union[Command, Group], /, *, guild: Optional[Snowflake] = None, override: bool = False): |
|
|
|
def add_command( |
|
|
|
self, |
|
|
|
command: Union[Command, ContextMenu, Group], |
|
|
|
/, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = None, |
|
|
|
override: bool = False, |
|
|
|
): |
|
|
|
"""Adds an application command to the tree. |
|
|
|
|
|
|
|
This only adds the command locally -- in order to sync the commands |
|
|
@ -133,7 +144,20 @@ class CommandTree: |
|
|
|
This is currently 100 for slash commands and 5 for context menu commands. |
|
|
|
""" |
|
|
|
|
|
|
|
if not isinstance(command, (Command, Group)): |
|
|
|
if isinstance(command, ContextMenu): |
|
|
|
guild_id = None if guild is None else guild.id |
|
|
|
type = command.type.value |
|
|
|
key = (command.name, guild_id, type) |
|
|
|
found = key in self._context_menus |
|
|
|
if found and not override: |
|
|
|
raise CommandAlreadyRegistered(command.name, guild_id) |
|
|
|
|
|
|
|
total = sum(1 for _, g, t in self._context_menus if g == guild_id and t == type) |
|
|
|
if total + found > 5: |
|
|
|
raise ValueError('maximum number of context menu commands exceeded (5)') |
|
|
|
self._context_menus[key] = command |
|
|
|
return |
|
|
|
elif not isinstance(command, (Command, Group)): |
|
|
|
raise TypeError(f'Expected a application command, received {command.__class__!r} instead') |
|
|
|
|
|
|
|
# todo: validate application command groups having children (required) |
|
|
@ -156,7 +180,36 @@ class CommandTree: |
|
|
|
raise ValueError('maximum number of slash commands exceeded (100)') |
|
|
|
self._global_commands[name] = root |
|
|
|
|
|
|
|
def remove_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]: |
|
|
|
@overload |
|
|
|
def remove_command( |
|
|
|
self, |
|
|
|
command: str, |
|
|
|
/, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = ..., |
|
|
|
type: Literal[AppCommandType.message, AppCommandType.user] = ..., |
|
|
|
) -> Optional[ContextMenu]: |
|
|
|
... |
|
|
|
|
|
|
|
@overload |
|
|
|
def remove_command( |
|
|
|
self, |
|
|
|
command: str, |
|
|
|
/, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = ..., |
|
|
|
type: Literal[AppCommandType.chat_input] = ..., |
|
|
|
) -> Optional[Union[Command, Group]]: |
|
|
|
... |
|
|
|
|
|
|
|
def remove_command( |
|
|
|
self, |
|
|
|
command: str, |
|
|
|
/, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = None, |
|
|
|
type: AppCommandType = AppCommandType.chat_input, |
|
|
|
) -> Optional[Union[Command, ContextMenu, Group]]: |
|
|
|
"""Removes an application command from the tree. |
|
|
|
|
|
|
|
This only removes the command locally -- in order to sync the commands |
|
|
@ -169,30 +222,63 @@ class CommandTree: |
|
|
|
guild: Optional[:class:`abc.Snowflake`] |
|
|
|
The guild to remove the command from. If not given then it |
|
|
|
removes a global command instead. |
|
|
|
type: :class:`AppCommandType` |
|
|
|
The type of command to remove. Defaults to :attr:`AppCommandType.chat_input`, |
|
|
|
i.e. slash commands. |
|
|
|
|
|
|
|
Returns |
|
|
|
--------- |
|
|
|
Optional[Union[:class:`Command`, :class:`Group`]] |
|
|
|
Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]] |
|
|
|
The application command that got removed. |
|
|
|
If nothing was removed then ``None`` is returned instead. |
|
|
|
""" |
|
|
|
|
|
|
|
if guild is None: |
|
|
|
return self._global_commands.pop(command, None) |
|
|
|
else: |
|
|
|
try: |
|
|
|
commands = self._guild_commands[guild.id] |
|
|
|
except KeyError: |
|
|
|
return None |
|
|
|
if type is AppCommandType.chat_input: |
|
|
|
if guild is None: |
|
|
|
return self._global_commands.pop(command, None) |
|
|
|
else: |
|
|
|
return commands.pop(command, None) |
|
|
|
|
|
|
|
def get_command(self, command: str, /, *, guild: Optional[Snowflake] = None) -> Optional[Union[Command, Group]]: |
|
|
|
"""Gets a application command from the tree. |
|
|
|
try: |
|
|
|
commands = self._guild_commands[guild.id] |
|
|
|
except KeyError: |
|
|
|
return None |
|
|
|
else: |
|
|
|
return commands.pop(command, None) |
|
|
|
elif type in (AppCommandType.user, AppCommandType.message): |
|
|
|
guild_id = None if guild is None else guild.id |
|
|
|
key = (command, guild_id, type.value) |
|
|
|
return self._context_menus.pop(key, None) |
|
|
|
|
|
|
|
@overload |
|
|
|
def get_command( |
|
|
|
self, |
|
|
|
command: str, |
|
|
|
/, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = ..., |
|
|
|
type: Literal[AppCommandType.message, AppCommandType.user] = ..., |
|
|
|
) -> Optional[ContextMenu]: |
|
|
|
... |
|
|
|
|
|
|
|
.. note:: |
|
|
|
@overload |
|
|
|
def get_command( |
|
|
|
self, |
|
|
|
command: str, |
|
|
|
/, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = ..., |
|
|
|
type: Literal[AppCommandType.chat_input] = ..., |
|
|
|
) -> Optional[Union[Command, Group]]: |
|
|
|
... |
|
|
|
|
|
|
|
This does *not* include context menu commands. |
|
|
|
def get_command( |
|
|
|
self, |
|
|
|
command: str, |
|
|
|
/, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = None, |
|
|
|
type: AppCommandType = AppCommandType.chat_input, |
|
|
|
) -> Optional[Union[Command, ContextMenu, Group]]: |
|
|
|
"""Gets a application command from the tree. |
|
|
|
|
|
|
|
Parameters |
|
|
|
----------- |
|
|
@ -201,52 +287,103 @@ class CommandTree: |
|
|
|
guild: Optional[:class:`abc.Snowflake`] |
|
|
|
The guild to get the command from. If not given then it |
|
|
|
gets a global command instead. |
|
|
|
type: :class:`AppCommandType` |
|
|
|
The type of command to get. Defaults to :attr:`AppCommandType.chat_input`, |
|
|
|
i.e. slash commands. |
|
|
|
|
|
|
|
Returns |
|
|
|
--------- |
|
|
|
Optional[Union[:class:`Command`, :class:`Group`]] |
|
|
|
Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]] |
|
|
|
The application command that was found. |
|
|
|
If nothing was found then ``None`` is returned instead. |
|
|
|
""" |
|
|
|
|
|
|
|
if guild is None: |
|
|
|
return self._global_commands.get(command) |
|
|
|
else: |
|
|
|
try: |
|
|
|
commands = self._guild_commands[guild.id] |
|
|
|
except KeyError: |
|
|
|
return None |
|
|
|
if type is AppCommandType.chat_input: |
|
|
|
if guild is None: |
|
|
|
return self._global_commands.get(command) |
|
|
|
else: |
|
|
|
return commands.get(command) |
|
|
|
|
|
|
|
def get_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group]]: |
|
|
|
"""Gets all application commands from the tree. |
|
|
|
try: |
|
|
|
commands = self._guild_commands[guild.id] |
|
|
|
except KeyError: |
|
|
|
return None |
|
|
|
else: |
|
|
|
return commands.get(command) |
|
|
|
elif type in (AppCommandType.user, AppCommandType.message): |
|
|
|
guild_id = None if guild is None else guild.id |
|
|
|
key = (command, guild_id, type.value) |
|
|
|
return self._context_menus.get(key) |
|
|
|
|
|
|
|
@overload |
|
|
|
def get_commands( |
|
|
|
self, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = ..., |
|
|
|
type: Literal[AppCommandType.message, AppCommandType.user] = ..., |
|
|
|
) -> List[ContextMenu]: |
|
|
|
... |
|
|
|
|
|
|
|
.. note:: |
|
|
|
@overload |
|
|
|
def get_commands( |
|
|
|
self, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = ..., |
|
|
|
type: Literal[AppCommandType.chat_input] = ..., |
|
|
|
) -> List[Union[Command, Group]]: |
|
|
|
... |
|
|
|
|
|
|
|
This does *not* retrieve context menu commands. |
|
|
|
def get_commands( |
|
|
|
self, |
|
|
|
*, |
|
|
|
guild: Optional[Snowflake] = None, |
|
|
|
type: AppCommandType = AppCommandType.chat_input, |
|
|
|
) -> Union[List[Union[Command, Group]], List[ContextMenu]]: |
|
|
|
"""Gets all application commands from the tree. |
|
|
|
|
|
|
|
Parameters |
|
|
|
----------- |
|
|
|
guild: Optional[:class:`~discord.abc.Snowflake`] |
|
|
|
The guild to get the commands from. If not given then it |
|
|
|
gets all global commands instead. |
|
|
|
type: :class:`AppCommandType` |
|
|
|
The type of commands to get. Defaults to :attr:`AppCommandType.chat_input`, |
|
|
|
i.e. slash commands. |
|
|
|
|
|
|
|
Returns |
|
|
|
--------- |
|
|
|
List[Union[:class:`Command`, :class:`Group`]] |
|
|
|
Union[List[:class:`ContextMenu`], List[Union[:class:`Command`, :class:`Group`]] |
|
|
|
The application commands from the tree. |
|
|
|
""" |
|
|
|
|
|
|
|
if type is AppCommandType.chat_input: |
|
|
|
if guild is None: |
|
|
|
return list(self._global_commands.values()) |
|
|
|
else: |
|
|
|
try: |
|
|
|
commands = self._guild_commands[guild.id] |
|
|
|
except KeyError: |
|
|
|
return [] |
|
|
|
else: |
|
|
|
return list(commands.values()) |
|
|
|
else: |
|
|
|
guild_id = None if guild is None else guild.id |
|
|
|
value = type.value |
|
|
|
return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value] |
|
|
|
|
|
|
|
def _get_all_commands(self, *, guild: Optional[Snowflake] = None) -> List[Union[Command, Group, ContextMenu]]: |
|
|
|
if guild is None: |
|
|
|
return list(self._global_commands.values()) |
|
|
|
base: List[Union[Command, Group, ContextMenu]] = list(self._global_commands.values()) |
|
|
|
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None) |
|
|
|
return base |
|
|
|
else: |
|
|
|
try: |
|
|
|
commands = self._guild_commands[guild.id] |
|
|
|
except KeyError: |
|
|
|
return [] |
|
|
|
return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None] |
|
|
|
else: |
|
|
|
return list(commands.values()) |
|
|
|
base: List[Union[Command, Group, ContextMenu]] = list(commands.values()) |
|
|
|
guild_id = guild.id |
|
|
|
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id) |
|
|
|
return base |
|
|
|
|
|
|
|
def command( |
|
|
|
self, |
|
|
@ -266,7 +403,7 @@ class CommandTree: |
|
|
|
The description of the application command. This shows up in the UI to describe |
|
|
|
the application command. If not given, it defaults to the first line of the docstring |
|
|
|
of the callback shortened to 100 characters. |
|
|
|
guild: Optional[:class:`Snowflake`] |
|
|
|
guild: Optional[:class:`.abc.Snowflake`] |
|
|
|
The guild to add the command to. If not given then it |
|
|
|
becomes a global command instead. |
|
|
|
""" |
|
|
@ -287,7 +424,6 @@ class CommandTree: |
|
|
|
name=name if name is not MISSING else func.__name__, |
|
|
|
description=desc, |
|
|
|
callback=func, |
|
|
|
type=AppCommandType.chat_input, |
|
|
|
parent=None, |
|
|
|
) |
|
|
|
self.add_command(command, guild=guild) |
|
|
@ -295,6 +431,49 @@ class CommandTree: |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
def context_menu( |
|
|
|
self, *, name: str = MISSING, guild: Optional[Snowflake] = None |
|
|
|
) -> Callable[[ContextMenuCallback], ContextMenu]: |
|
|
|
"""Creates a application command context menu from a regular function directly under this tree. |
|
|
|
|
|
|
|
This function must have a signature of :class:`~discord.Interaction` as its first parameter |
|
|
|
and taking either a :class:`~discord.Member`, :class:`~discord.User`, or :class:`~discord.Message`, |
|
|
|
or a :obj:`typing.Union` of ``Member`` and ``User`` as its second parameter. |
|
|
|
|
|
|
|
Examples |
|
|
|
--------- |
|
|
|
|
|
|
|
.. code-block:: python3 |
|
|
|
|
|
|
|
@app_commands.context_menu() |
|
|
|
async def react(interaction: discord.Interaction, message: discord.Message): |
|
|
|
await interaction.response.send_message('Very cool message!', ephemeral=True) |
|
|
|
|
|
|
|
@app_commands.context_menu() |
|
|
|
async def ban(interaction: discord.Interaction, user: discord.Member): |
|
|
|
await interaction.response.send_message(f'Should I actually ban {user}...', ephemeral=True) |
|
|
|
|
|
|
|
Parameters |
|
|
|
------------ |
|
|
|
name: :class:`str` |
|
|
|
The name of the context menu command. If not given, it defaults to a title-case |
|
|
|
version of the callback name. Note that unlike regular slash commands this can |
|
|
|
have spaces and upper case characters in the name. |
|
|
|
guild: Optional[:class:`.abc.Snowflake`] |
|
|
|
The guild to add the command to. If not given then it |
|
|
|
becomes a global command instead. |
|
|
|
""" |
|
|
|
|
|
|
|
def decorator(func: ContextMenuCallback) -> ContextMenu: |
|
|
|
if not inspect.iscoroutinefunction(func): |
|
|
|
raise TypeError('context menu function must be a coroutine function') |
|
|
|
|
|
|
|
context_menu = ContextMenu._from_decorator(func, name=name) |
|
|
|
self.add_command(context_menu, guild=guild) |
|
|
|
return context_menu |
|
|
|
|
|
|
|
return decorator |
|
|
|
|
|
|
|
async def sync(self, *, guild: Optional[Snowflake]) -> List[AppCommand]: |
|
|
|
"""|coro| |
|
|
|
|
|
|
@ -327,7 +506,7 @@ class CommandTree: |
|
|
|
if self.client.application_id is None: |
|
|
|
raise ClientException('Client does not have an application ID set') |
|
|
|
|
|
|
|
commands = self.get_commands(guild=guild) |
|
|
|
commands = self._get_all_commands(guild=guild) |
|
|
|
payload = [command.to_dict() for command in commands] |
|
|
|
if guild is None: |
|
|
|
data = await self._http.bulk_upsert_global_commands(self.client.application_id, payload=payload) |
|
|
@ -345,6 +524,25 @@ class CommandTree: |
|
|
|
|
|
|
|
self.client.loop.create_task(wrapper(), name='CommandTree-invoker') |
|
|
|
|
|
|
|
async def _call_context_menu(self, interaction: Interaction, data: ApplicationCommandInteractionData, type: int): |
|
|
|
name = data['name'] |
|
|
|
guild_id = interaction.guild_id |
|
|
|
ctx_menu = self._context_menus.get((name, guild_id, type)) |
|
|
|
if ctx_menu is None: |
|
|
|
raise CommandNotFound(name, [], AppCommandType(type)) |
|
|
|
|
|
|
|
resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) |
|
|
|
# This will always work at runtime |
|
|
|
value = resolved.get(data.get('target_id')) # type: ignore |
|
|
|
if ctx_menu.type.value != type: |
|
|
|
raise CommandSignatureMismatch(ctx_menu) |
|
|
|
|
|
|
|
if value is None: |
|
|
|
raise RuntimeError('This should not happen if Discord sent well-formed data.') |
|
|
|
|
|
|
|
# I assume I don't have to type check here. |
|
|
|
await ctx_menu._invoke(interaction, value) |
|
|
|
|
|
|
|
async def call(self, interaction: Interaction): |
|
|
|
"""|coro| |
|
|
|
|
|
|
@ -367,6 +565,12 @@ class CommandTree: |
|
|
|
application command definition. |
|
|
|
""" |
|
|
|
data: ApplicationCommandInteractionData = interaction.data # type: ignore |
|
|
|
type = data.get('type', 1) |
|
|
|
if type != 1: |
|
|
|
# Context menu command... |
|
|
|
await self._call_context_menu(interaction, data, type) |
|
|
|
return |
|
|
|
|
|
|
|
parents: List[str] = [] |
|
|
|
name = data['name'] |
|
|
|
command = self._global_commands.get(name) |
|
|
|