From d2015c1053efed85faa356ead8aa33b60b9497c0 Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Sun, 18 Aug 2024 22:58:35 +0200 Subject: [PATCH 1/7] Store command ids on tree --- discord/app_commands/tree.py | 108 ++++++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 3 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index bc0d68ec7..3d89ce5f5 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -161,6 +161,9 @@ class CommandTree(Generic[ClientT]): # it's uncommon and N=5 anyway. self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} + self._global_command_ids: Dict[str, int] = {} + self._guild_command_ids: Dict[Tuple[str, int], int] = {} + async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand: """|coro| @@ -198,7 +201,9 @@ class CommandTree(Generic[ClientT]): else: command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id) - return AppCommand(data=command, state=self._state) + res = AppCommand(data=command, state=self._state) + self._store_command_id((res, res.id)) + return res async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: """|coro| @@ -238,7 +243,79 @@ class CommandTree(Generic[ClientT]): else: commands = await self._http.get_guild_commands(self.client.application_id, guild.id) - return [AppCommand(data=data, state=self._state) for data in commands] + res = [AppCommand(data=command, state=self._state) for command in commands] + self._store_command_id(*((cmd, cmd.id) for cmd in res)) + return res + + def get_command_id( + self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None + ) -> Optional[int]: + """Gets the command ID for a command. + + Parameters + ----------- + name: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`] + The name of the command to get the ID for. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to get the command ID for. If not passed then the global command + ID is fetched instead. + + Returns + -------- + Optional[:class:`~discord.app_commands.CommandID`] + The command ID if found, otherwise ``None``. + + .. note:: + + Group commands will return the ID of the root command. Subcommands do not have their own IDs. + """ + name: Optional[str] = None + + if isinstance(command, AppCommand): + return command.id + elif isinstance(command, (Command, Group)): + name = (command.root_parent or command).name + elif isinstance(command, ContextMenu): + name = command.name + elif isinstance(command, str): + name = command.split()[0] + + return self._global_command_ids.get(name) if guild is None else self._guild_command_ids.get((name, guild.id)) + + def get_command_mention( + self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None + ) -> str | None: + """Gets the mention string for a command. + + Parameters + ----------- + command: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`] + The command to get the mention string for. + + Returns + -------- + Optional[:class:`str`] + The mention string for the command if found, otherwise ``None``. + + .. note:: + + Remember that groups cannot be mentioned, only with a subcommand. + """ + if isinstance(command, AppCommand): + return command.mention + + command_id = self.get_command_id(command, guild=guild) + if command_id is None: + return None + + if isinstance(command, (Command, Group)): + full_name = command.qualified_name + elif isinstance(command, ContextMenu): + full_name = command.name + elif isinstance(command, str): + full_name = command + + return f'' def copy_global_to(self, *, guild: Snowflake) -> None: """Copies all global commands to the specified guild. @@ -1134,7 +1211,29 @@ class CommandTree(Generic[ClientT]): raise CommandSyncFailure(e, commands) from None raise - return [AppCommand(data=d, state=self._state) for d in data] + res = [AppCommand(data=d, state=self._state) for d in data] + self._store_command_id(*((cmd, cmd.id) for cmd in res)) + return res + + def _store_command_id(self, *commands: Tuple[AppCommand | ContextMenu | Command[Any, ..., Any] | Group, int]) -> None: + for command, command_id in commands: + if isinstance(command, AppCommand): + guild_id = command.guild_id + if guild_id is None: + self._global_command_ids[command.name] = command_id + else: + key = (command.name, guild_id) + self._guild_command_ids[key] = command_id + else: + guild_ids = command._guild_ids + name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name + + if not guild_ids: + self._global_command_ids[name] = command_id + else: + for guild_id in guild_ids: + key = (name, guild_id) + self._guild_command_ids[key] = command_id async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: command = interaction.command @@ -1231,6 +1330,8 @@ class CommandTree(Generic[ClientT]): if ctx_menu is None: raise CommandNotFound(name, [], AppCommandType(type)) + self._store_command_id((ctx_menu, int(data['id']))) + resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) # This is annotated as str | int but realistically this will always be str @@ -1281,6 +1382,7 @@ class CommandTree(Generic[ClientT]): return command, options = self._get_app_command_options(data) + self._store_command_id((command, int(data['id']))) # Pre-fill the cached slot to prevent re-computation interaction._cs_command = command From b82fc5b3d32b79ce41f3233268e3cd9ddaf399f9 Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Sun, 18 Aug 2024 23:06:47 +0200 Subject: [PATCH 2/7] Parse from dicts instead --- discord/app_commands/tree.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index 3d89ce5f5..ef77b63f3 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -68,6 +68,7 @@ from .._types import ClientT if TYPE_CHECKING: from ..types.interactions import ApplicationCommandInteractionData, ApplicationCommandInteractionDataOption + from ..types.command import ApplicationCommand from ..interactions import Interaction from ..abc import Snowflake from .commands import ContextMenuCallback, CommandCallback, P, T @@ -202,7 +203,8 @@ class CommandTree(Generic[ClientT]): command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id) res = AppCommand(data=command, state=self._state) - self._store_command_id((res, res.id)) + # self._store_command_id((res, res.id)) + self._store_command_from_data(command) return res async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: @@ -244,7 +246,8 @@ class CommandTree(Generic[ClientT]): commands = await self._http.get_guild_commands(self.client.application_id, guild.id) res = [AppCommand(data=command, state=self._state) for command in commands] - self._store_command_id(*((cmd, cmd.id) for cmd in res)) + # self._store_command_id(*((cmd, cmd.id) for cmd in res)) + self._store_command_from_data(*commands) return res def get_command_id( @@ -1212,7 +1215,8 @@ class CommandTree(Generic[ClientT]): raise res = [AppCommand(data=d, state=self._state) for d in data] - self._store_command_id(*((cmd, cmd.id) for cmd in res)) + # self._store_command_id(*((cmd, cmd.id) for cmd in res)) + self._store_command_from_data(*data) return res def _store_command_id(self, *commands: Tuple[AppCommand | ContextMenu | Command[Any, ..., Any] | Group, int]) -> None: @@ -1235,6 +1239,16 @@ class CommandTree(Generic[ClientT]): key = (name, guild_id) self._guild_command_ids[key] = command_id + def _store_command_from_data(self, *data: ApplicationCommandInteractionData | ApplicationCommand) -> None: + for d in data: + command_id = int(d['id']) + name = d['name'] + guild_id = _get_as_snowflake(d, 'guild_id') + if guild_id is None: + self._global_command_ids[name] = command_id + else: + self._guild_command_ids[(name, guild_id)] = command_id + async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: command = interaction.command interaction.command_failed = True @@ -1330,7 +1344,8 @@ class CommandTree(Generic[ClientT]): if ctx_menu is None: raise CommandNotFound(name, [], AppCommandType(type)) - self._store_command_id((ctx_menu, int(data['id']))) + # self._store_command_id((ctx_menu, int(data['id']))) + self._store_command_from_data(data) resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) @@ -1382,7 +1397,8 @@ class CommandTree(Generic[ClientT]): return command, options = self._get_app_command_options(data) - self._store_command_id((command, int(data['id']))) + # self._store_command_id((command, int(data['id']))) + self._store_command_from_data(data) # Pre-fill the cached slot to prevent re-computation interaction._cs_command = command From 886d226641e3b9c05b64e6f2a999a3baf4abfd57 Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Sun, 18 Aug 2024 23:27:41 +0200 Subject: [PATCH 3/7] Combine ContextMenu if-statement --- discord/app_commands/tree.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index ef77b63f3..d41e8b95b 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -276,10 +276,9 @@ class CommandTree(Generic[ClientT]): if isinstance(command, AppCommand): return command.id - elif isinstance(command, (Command, Group)): - name = (command.root_parent or command).name - elif isinstance(command, ContextMenu): - name = command.name + + if isinstance(command, (Command, Group, ContextMenu)): + name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name elif isinstance(command, str): name = command.split()[0] @@ -311,10 +310,8 @@ class CommandTree(Generic[ClientT]): if command_id is None: return None - if isinstance(command, (Command, Group)): + if isinstance(command, (Command, Group, ContextMenu)): full_name = command.qualified_name - elif isinstance(command, ContextMenu): - full_name = command.name elif isinstance(command, str): full_name = command From 5f1affc4f92e287329260d3bf967c503464be358 Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:57:54 +0200 Subject: [PATCH 4/7] Finish impl --- discord/app_commands/tree.py | 93 ++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 46 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index d41e8b95b..f6339de82 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -133,6 +133,13 @@ class CommandTree(Generic[ClientT]): Note that you can override this on a per command basis. .. versionadded:: 2.4 + store_app_command_ids: :class:`bool` + Whether to store the application command IDs on the tree. These can be used to mention a command. + Defaults to ``False``. + + This must be enabled if you want to use :meth:`get_command_mention` or :meth:`get_command_id`. + + .. versionadded:: 2.5 """ def __init__( @@ -142,6 +149,7 @@ class CommandTree(Generic[ClientT]): fallback_to_global: bool = True, allowed_contexts: AppCommandContext = MISSING, allowed_installs: AppInstallationType = MISSING, + store_app_command_ids: bool = False, ): self.client: ClientT = client self._http = client.http @@ -162,8 +170,8 @@ class CommandTree(Generic[ClientT]): # it's uncommon and N=5 anyway. self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} - self._global_command_ids: Dict[str, int] = {} - self._guild_command_ids: Dict[Tuple[str, int], int] = {} + self.store_app_command_ids: bool = store_app_command_ids + self._command_ids: Dict[Optional[int], Dict[str, int]] = {} async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand: """|coro| @@ -202,10 +210,8 @@ class CommandTree(Generic[ClientT]): else: command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id) - res = AppCommand(data=command, state=self._state) - # self._store_command_id((res, res.id)) - self._store_command_from_data(command) - return res + self._update_command_ids(command) + return AppCommand(data=command, state=self._state) async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: """|coro| @@ -245,10 +251,8 @@ class CommandTree(Generic[ClientT]): else: commands = await self._http.get_guild_commands(self.client.application_id, guild.id) - res = [AppCommand(data=command, state=self._state) for command in commands] - # self._store_command_id(*((cmd, cmd.id) for cmd in res)) - self._store_command_from_data(*commands) - return res + self._update_command_ids(*commands) + return [AppCommand(data=command, state=self._state) for command in commands] def get_command_id( self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None @@ -265,7 +269,7 @@ class CommandTree(Generic[ClientT]): Returns -------- - Optional[:class:`~discord.app_commands.CommandID`] + Optional[:class:`int`] The command ID if found, otherwise ``None``. .. note:: @@ -276,17 +280,17 @@ class CommandTree(Generic[ClientT]): if isinstance(command, AppCommand): return command.id - + if isinstance(command, (Command, Group, ContextMenu)): name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name elif isinstance(command, str): name = command.split()[0] - return self._global_command_ids.get(name) if guild is None else self._guild_command_ids.get((name, guild.id)) + return self._command_ids.get(guild.id if guild else None, {}).get(name) def get_command_mention( self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None - ) -> str | None: + ) -> Optional[str]: """Gets the mention string for a command. Parameters @@ -317,6 +321,22 @@ class CommandTree(Generic[ClientT]): return f'' + def get_command_ids(self, guild: Optional[Snowflake] = None) -> Dict[str, int]: + """Gets all command IDs for the given guild. + + Parameters + ----------- + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to get the command IDs for. If not passed then the global command + IDs are returned instead. + + Returns + -------- + Dict[:class:`str`, :class:`int`] + A dictionary of command names and their IDs. + """ + return self._command_ids.get(guild.id if guild else None, {}) + def copy_global_to(self, *, guild: Snowflake) -> None: """Copies all global commands to the specified guild. @@ -1212,39 +1232,21 @@ class CommandTree(Generic[ClientT]): raise res = [AppCommand(data=d, state=self._state) for d in data] - # self._store_command_id(*((cmd, cmd.id) for cmd in res)) - self._store_command_from_data(*data) + self._update_command_ids(*data) return res - def _store_command_id(self, *commands: Tuple[AppCommand | ContextMenu | Command[Any, ..., Any] | Group, int]) -> None: - for command, command_id in commands: - if isinstance(command, AppCommand): - guild_id = command.guild_id - if guild_id is None: - self._global_command_ids[command.name] = command_id - else: - key = (command.name, guild_id) - self._guild_command_ids[key] = command_id - else: - guild_ids = command._guild_ids - name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name - - if not guild_ids: - self._global_command_ids[name] = command_id - else: - for guild_id in guild_ids: - key = (name, guild_id) - self._guild_command_ids[key] = command_id + def _update_command_ids(self, *data: Union[ApplicationCommandInteractionData, ApplicationCommand]) -> None: + if not self.store_app_command_ids: + return - def _store_command_from_data(self, *data: ApplicationCommandInteractionData | ApplicationCommand) -> None: for d in data: - command_id = int(d['id']) - name = d['name'] - guild_id = _get_as_snowflake(d, 'guild_id') - if guild_id is None: - self._global_command_ids[name] = command_id - else: - self._guild_command_ids[(name, guild_id)] = command_id + command_id: int = int(d['id']) + name: str = d['name'] + guild_id: Optional[int] = _get_as_snowflake(d, 'guild_id') + try: + self._command_ids[guild_id][name] = command_id + except KeyError: + self._command_ids[guild_id] = {name: command_id} async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: command = interaction.command @@ -1341,8 +1343,7 @@ class CommandTree(Generic[ClientT]): if ctx_menu is None: raise CommandNotFound(name, [], AppCommandType(type)) - # self._store_command_id((ctx_menu, int(data['id']))) - self._store_command_from_data(data) + self._update_command_ids(data) resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) @@ -1395,7 +1396,7 @@ class CommandTree(Generic[ClientT]): command, options = self._get_app_command_options(data) # self._store_command_id((command, int(data['id']))) - self._store_command_from_data(data) + self._update_command_ids(data) # Pre-fill the cached slot to prevent re-computation interaction._cs_command = command From b4431f6c27fb692436410920dae7e65d7cb3bd1c Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:10:58 +0200 Subject: [PATCH 5/7] Update ids on delete & edit too --- discord/app_commands/models.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/discord/app_commands/models.py b/discord/app_commands/models.py index e8a96784b..b7d7f6b85 100644 --- a/discord/app_commands/models.py +++ b/discord/app_commands/models.py @@ -309,6 +309,10 @@ class AppCommand(Hashable): self.id, ) + tree = self._state._command_tree + if tree: + tree._command_ids.get(self.guild_id, {}).pop(self.name, None) + async def edit( self, *, @@ -392,6 +396,11 @@ class AppCommand(Hashable): self.id, payload, ) + + tree = self._state._command_tree + if tree: + tree._update_command_ids(data) + return AppCommand(data=data, state=state) async def fetch_permissions(self, guild: Snowflake) -> GuildAppCommandPermissions: From b2a1cb22bbc855ffb09f154670b2868404b8daaa Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:11:10 +0200 Subject: [PATCH 6/7] Update ids once in _call --- discord/app_commands/tree.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index f6339de82..c87996dc5 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -1343,8 +1343,6 @@ class CommandTree(Generic[ClientT]): if ctx_menu is None: raise CommandNotFound(name, [], AppCommandType(type)) - self._update_command_ids(data) - resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) # This is annotated as str | int but realistically this will always be str @@ -1388,6 +1386,7 @@ class CommandTree(Generic[ClientT]): return data: ApplicationCommandInteractionData = interaction.data # type: ignore + self._update_command_ids(data) type = data.get('type', 1) if type != 1: # Context menu command... @@ -1395,8 +1394,6 @@ class CommandTree(Generic[ClientT]): return command, options = self._get_app_command_options(data) - # self._store_command_id((command, int(data['id']))) - self._update_command_ids(data) # Pre-fill the cached slot to prevent re-computation interaction._cs_command = command From 5596d736fd7e273197fb72d1286b32c1f9dc5462 Mon Sep 17 00:00:00 2001 From: Soheab_ <33902984+Soheab@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:15:00 +0200 Subject: [PATCH 7/7] Update tree.py --- discord/app_commands/tree.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index c87996dc5..485da7af9 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -1231,9 +1231,8 @@ class CommandTree(Generic[ClientT]): raise CommandSyncFailure(e, commands) from None raise - res = [AppCommand(data=d, state=self._state) for d in data] self._update_command_ids(*data) - return res + return [AppCommand(data=d, state=self._state) for d in data] def _update_command_ids(self, *data: Union[ApplicationCommandInteractionData, ApplicationCommand]) -> None: if not self.store_app_command_ids: