Browse Source

Store command ids on tree

pull/9924/head
Soheab_ 8 months ago
parent
commit
d2015c1053
  1. 108
      discord/app_commands/tree.py

108
discord/app_commands/tree.py

@ -161,6 +161,9 @@ class CommandTree(Generic[ClientT]):
# it's uncommon and N=5 anyway.
self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {}
self._global_command_ids: Dict[str, int] = {}
self._guild_command_ids: Dict[Tuple[str, int], int] = {}
async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand:
"""|coro|
@ -198,7 +201,9 @@ class CommandTree(Generic[ClientT]):
else:
command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id)
return AppCommand(data=command, state=self._state)
res = AppCommand(data=command, state=self._state)
self._store_command_id((res, res.id))
return res
async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
"""|coro|
@ -238,7 +243,79 @@ class CommandTree(Generic[ClientT]):
else:
commands = await self._http.get_guild_commands(self.client.application_id, guild.id)
return [AppCommand(data=data, state=self._state) for data in commands]
res = [AppCommand(data=command, state=self._state) for command in commands]
self._store_command_id(*((cmd, cmd.id) for cmd in res))
return res
def get_command_id(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
) -> Optional[int]:
"""Gets the command ID for a command.
Parameters
-----------
name: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`]
The name of the command to get the ID for.
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to get the command ID for. If not passed then the global command
ID is fetched instead.
Returns
--------
Optional[:class:`~discord.app_commands.CommandID`]
The command ID if found, otherwise ``None``.
.. note::
Group commands will return the ID of the root command. Subcommands do not have their own IDs.
"""
name: Optional[str] = None
if isinstance(command, AppCommand):
return command.id
elif isinstance(command, (Command, Group)):
name = (command.root_parent or command).name
elif isinstance(command, ContextMenu):
name = command.name
elif isinstance(command, str):
name = command.split()[0]
return self._global_command_ids.get(name) if guild is None else self._guild_command_ids.get((name, guild.id))
def get_command_mention(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
) -> str | None:
"""Gets the mention string for a command.
Parameters
-----------
command: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`]
The command to get the mention string for.
Returns
--------
Optional[:class:`str`]
The mention string for the command if found, otherwise ``None``.
.. note::
Remember that groups cannot be mentioned, only with a subcommand.
"""
if isinstance(command, AppCommand):
return command.mention
command_id = self.get_command_id(command, guild=guild)
if command_id is None:
return None
if isinstance(command, (Command, Group)):
full_name = command.qualified_name
elif isinstance(command, ContextMenu):
full_name = command.name
elif isinstance(command, str):
full_name = command
return f'</{full_name}:{command_id}>'
def copy_global_to(self, *, guild: Snowflake) -> None:
"""Copies all global commands to the specified guild.
@ -1134,7 +1211,29 @@ class CommandTree(Generic[ClientT]):
raise CommandSyncFailure(e, commands) from None
raise
return [AppCommand(data=d, state=self._state) for d in data]
res = [AppCommand(data=d, state=self._state) for d in data]
self._store_command_id(*((cmd, cmd.id) for cmd in res))
return res
def _store_command_id(self, *commands: Tuple[AppCommand | ContextMenu | Command[Any, ..., Any] | Group, int]) -> None:
for command, command_id in commands:
if isinstance(command, AppCommand):
guild_id = command.guild_id
if guild_id is None:
self._global_command_ids[command.name] = command_id
else:
key = (command.name, guild_id)
self._guild_command_ids[key] = command_id
else:
guild_ids = command._guild_ids
name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name
if not guild_ids:
self._global_command_ids[name] = command_id
else:
for guild_id in guild_ids:
key = (name, guild_id)
self._guild_command_ids[key] = command_id
async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None:
command = interaction.command
@ -1231,6 +1330,8 @@ class CommandTree(Generic[ClientT]):
if ctx_menu is None:
raise CommandNotFound(name, [], AppCommandType(type))
self._store_command_id((ctx_menu, int(data['id'])))
resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {}))
# This is annotated as str | int but realistically this will always be str
@ -1281,6 +1382,7 @@ class CommandTree(Generic[ClientT]):
return
command, options = self._get_app_command_options(data)
self._store_command_id((command, int(data['id'])))
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command

Loading…
Cancel
Save