Browse Source

Finish impl

pull/9924/head
Soheab_ 10 months ago
parent
commit
5f1affc4f9
  1. 91
      discord/app_commands/tree.py

91
discord/app_commands/tree.py

@ -133,6 +133,13 @@ class CommandTree(Generic[ClientT]):
Note that you can override this on a per command basis. Note that you can override this on a per command basis.
.. versionadded:: 2.4 .. versionadded:: 2.4
store_app_command_ids: :class:`bool`
Whether to store the application command IDs on the tree. These can be used to mention a command.
Defaults to ``False``.
This must be enabled if you want to use :meth:`get_command_mention` or :meth:`get_command_id`.
.. versionadded:: 2.5
""" """
def __init__( def __init__(
@ -142,6 +149,7 @@ class CommandTree(Generic[ClientT]):
fallback_to_global: bool = True, fallback_to_global: bool = True,
allowed_contexts: AppCommandContext = MISSING, allowed_contexts: AppCommandContext = MISSING,
allowed_installs: AppInstallationType = MISSING, allowed_installs: AppInstallationType = MISSING,
store_app_command_ids: bool = False,
): ):
self.client: ClientT = client self.client: ClientT = client
self._http = client.http self._http = client.http
@ -162,8 +170,8 @@ class CommandTree(Generic[ClientT]):
# it's uncommon and N=5 anyway. # it's uncommon and N=5 anyway.
self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {}
self._global_command_ids: Dict[str, int] = {} self.store_app_command_ids: bool = store_app_command_ids
self._guild_command_ids: Dict[Tuple[str, int], int] = {} self._command_ids: Dict[Optional[int], Dict[str, int]] = {}
async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand: async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand:
"""|coro| """|coro|
@ -202,10 +210,8 @@ class CommandTree(Generic[ClientT]):
else: else:
command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id) command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id)
res = AppCommand(data=command, state=self._state) self._update_command_ids(command)
# self._store_command_id((res, res.id)) return AppCommand(data=command, state=self._state)
self._store_command_from_data(command)
return res
async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
"""|coro| """|coro|
@ -245,10 +251,8 @@ class CommandTree(Generic[ClientT]):
else: else:
commands = await self._http.get_guild_commands(self.client.application_id, guild.id) commands = await self._http.get_guild_commands(self.client.application_id, guild.id)
res = [AppCommand(data=command, state=self._state) for command in commands] self._update_command_ids(*commands)
# self._store_command_id(*((cmd, cmd.id) for cmd in res)) return [AppCommand(data=command, state=self._state) for command in commands]
self._store_command_from_data(*commands)
return res
def get_command_id( def get_command_id(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
@ -265,7 +269,7 @@ class CommandTree(Generic[ClientT]):
Returns Returns
-------- --------
Optional[:class:`~discord.app_commands.CommandID`] Optional[:class:`int`]
The command ID if found, otherwise ``None``. The command ID if found, otherwise ``None``.
.. note:: .. note::
@ -282,11 +286,11 @@ class CommandTree(Generic[ClientT]):
elif isinstance(command, str): elif isinstance(command, str):
name = command.split()[0] name = command.split()[0]
return self._global_command_ids.get(name) if guild is None else self._guild_command_ids.get((name, guild.id)) return self._command_ids.get(guild.id if guild else None, {}).get(name)
def get_command_mention( def get_command_mention(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
) -> str | None: ) -> Optional[str]:
"""Gets the mention string for a command. """Gets the mention string for a command.
Parameters Parameters
@ -317,6 +321,22 @@ class CommandTree(Generic[ClientT]):
return f'</{full_name}:{command_id}>' return f'</{full_name}:{command_id}>'
def get_command_ids(self, guild: Optional[Snowflake] = None) -> Dict[str, int]:
"""Gets all command IDs for the given guild.
Parameters
-----------
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to get the command IDs for. If not passed then the global command
IDs are returned instead.
Returns
--------
Dict[:class:`str`, :class:`int`]
A dictionary of command names and their IDs.
"""
return self._command_ids.get(guild.id if guild else None, {})
def copy_global_to(self, *, guild: Snowflake) -> None: def copy_global_to(self, *, guild: Snowflake) -> None:
"""Copies all global commands to the specified guild. """Copies all global commands to the specified guild.
@ -1212,39 +1232,21 @@ class CommandTree(Generic[ClientT]):
raise raise
res = [AppCommand(data=d, state=self._state) for d in data] res = [AppCommand(data=d, state=self._state) for d in data]
# self._store_command_id(*((cmd, cmd.id) for cmd in res)) self._update_command_ids(*data)
self._store_command_from_data(*data)
return res return res
def _store_command_id(self, *commands: Tuple[AppCommand | ContextMenu | Command[Any, ..., Any] | Group, int]) -> None: def _update_command_ids(self, *data: Union[ApplicationCommandInteractionData, ApplicationCommand]) -> None:
for command, command_id in commands: if not self.store_app_command_ids:
if isinstance(command, AppCommand): return
guild_id = command.guild_id
if guild_id is None:
self._global_command_ids[command.name] = command_id
else:
key = (command.name, guild_id)
self._guild_command_ids[key] = command_id
else:
guild_ids = command._guild_ids
name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name
if not guild_ids:
self._global_command_ids[name] = command_id
else:
for guild_id in guild_ids:
key = (name, guild_id)
self._guild_command_ids[key] = command_id
def _store_command_from_data(self, *data: ApplicationCommandInteractionData | ApplicationCommand) -> None:
for d in data: for d in data:
command_id = int(d['id']) command_id: int = int(d['id'])
name = d['name'] name: str = d['name']
guild_id = _get_as_snowflake(d, 'guild_id') guild_id: Optional[int] = _get_as_snowflake(d, 'guild_id')
if guild_id is None: try:
self._global_command_ids[name] = command_id self._command_ids[guild_id][name] = command_id
else: except KeyError:
self._guild_command_ids[(name, guild_id)] = command_id self._command_ids[guild_id] = {name: command_id}
async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None:
command = interaction.command command = interaction.command
@ -1341,8 +1343,7 @@ class CommandTree(Generic[ClientT]):
if ctx_menu is None: if ctx_menu is None:
raise CommandNotFound(name, [], AppCommandType(type)) raise CommandNotFound(name, [], AppCommandType(type))
# self._store_command_id((ctx_menu, int(data['id']))) self._update_command_ids(data)
self._store_command_from_data(data)
resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {}))
@ -1395,7 +1396,7 @@ class CommandTree(Generic[ClientT]):
command, options = self._get_app_command_options(data) command, options = self._get_app_command_options(data)
# self._store_command_id((command, int(data['id']))) # self._store_command_id((command, int(data['id'])))
self._store_command_from_data(data) self._update_command_ids(data)
# Pre-fill the cached slot to prevent re-computation # Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command interaction._cs_command = command

Loading…
Cancel
Save