From e6a87e0782b453ea108a0f406ee0cc3a2f90aa29 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 7 Mar 2022 21:52:58 -0500 Subject: [PATCH] Add support for adding app commands locally to many guilds This affects the context_menu and command decorators as well. Removing and syncing do not support multiple guild IDs. --- discord/app_commands/tree.py | 113 +++++++++++++++++++++++++++-------- 1 file changed, 88 insertions(+), 25 deletions(-) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index b4582b6bf..dadcf3dad 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -26,7 +26,7 @@ from __future__ import annotations import inspect import sys import traceback -from typing import Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, overload +from typing import Callable, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, Set, Tuple, TypeVar, Union, overload from .namespace import Namespace, ResolveKey @@ -54,6 +54,23 @@ __all__ = ('CommandTree',) ClientT = TypeVar('ClientT', bound='Client') +def _retrieve_guild_ids(guild: Optional[Snowflake] = MISSING, guilds: List[Snowflake] = MISSING) -> Optional[Set[int]]: + if guild is not MISSING and guilds is not MISSING: + raise TypeError('cannot mix guild and guilds keyword arguments') + + # guilds=[] or guilds=[...] or no args at all + if guild is MISSING: + if not guilds: + return None + return {g.id for g in guilds} + + # At this point it should be... + # guild=None or guild=Object + if guild is None: + return None + return {guild.id} + + class CommandTree(Generic[ClientT]): """Represents a container that holds application command information. @@ -121,7 +138,8 @@ class CommandTree(Generic[ClientT]): command: Union[Command, ContextMenu, Group], /, *, - guild: Optional[Snowflake] = None, + guild: Optional[Snowflake] = MISSING, + guilds: List[Snowflake] = MISSING, override: bool = False, ): """Adds an application command to the tree. @@ -138,6 +156,10 @@ class CommandTree(Generic[ClientT]): guild: Optional[:class:`~discord.abc.Snowflake`] The guild to add the command to. If not given then it becomes a global command instead. + guilds: List[:class:`~discord.abc.Snowflake`] + The list of guilds to add the command to. This cannot be mixed + with the ``guild`` parameter. If no guilds are given at all + then it becomes a global command instead. override: :class:`bool` Whether to override a command with the same name. If ``False`` an exception is raised. Default is ``False``. @@ -148,23 +170,44 @@ class CommandTree(Generic[ClientT]): The command was already registered and no override was specified. TypeError The application command passed is not a valid application command. + Or, ``guild`` and ``guilds`` were both given. ValueError The maximum number of commands was reached globally or for that guild. This is currently 100 for slash commands and 5 for context menu commands. """ + guild_ids = _retrieve_guild_ids(guild, guilds) if isinstance(command, ContextMenu): - guild_id = None if guild is None else guild.id type = command.type.value - key = (command.name, guild_id, type) - found = key in self._context_menus - if found and not override: - raise CommandAlreadyRegistered(command.name, guild_id) - - total = sum(1 for _, g, t in self._context_menus if g == guild_id and t == type) - if total + found > 5: - raise ValueError('maximum number of context menu commands exceeded (5)') - self._context_menus[key] = command + name = command.name + + def _context_menu_add_helper( + guild_id: Optional[int], + data: Dict[Tuple[str, Optional[int], int], ContextMenu], + name: str = name, + type: int = type, + ) -> None: + key = (name, guild_id, type) + found = key in self._context_menus + if found and not override: + raise CommandAlreadyRegistered(name, guild_id) + + total = sum(1 for _, g, t in self._context_menus if g == guild_id and t == type) + if total + found > 5: + raise ValueError('maximum number of context menu commands exceeded (5)') + data[key] = command + + if guild_ids is None: + _context_menu_add_helper(None, self._context_menus) + else: + current: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} + for guild_id in guild_ids: + _context_menu_add_helper(guild_id, current) + + # Update at the end in order to make sure the update is atomic. + # An error during addition could end up making the context menu mapping + # have a partial state + self._context_menus.update(current) return elif not isinstance(command, (Command, Group)): raise TypeError(f'Expected a application command, received {command.__class__!r} instead') @@ -173,20 +216,27 @@ class CommandTree(Generic[ClientT]): root = command.root_parent or command name = root.name - if guild is not None: - commands = self._guild_commands.setdefault(guild.id, {}) - found = name in commands - if found and not override: - raise CommandAlreadyRegistered(name, guild.id) - if len(commands) + found > 100: - raise ValueError('maximum number of slash commands exceeded (100)') - commands[name] = root + if guild_ids is not None: + # Validate that the command can be added first, before actually + # adding it into the mapping. This ensures atomicity. + for guild_id in guild_ids: + commands = self._guild_commands.get(guild_id, {}) + found = name in commands + if found and not override: + raise CommandAlreadyRegistered(name, guild_id) + if len(commands) + found > 100: + raise ValueError(f'maximum number of slash commands exceeded (100) for guild_id {guild_id}') + + # Actually add the command now that it has been verified to be okay. + for guild_id in guild_ids: + commands = self._guild_commands.setdefault(guild_id, {}) + commands[name] = root else: found = name in self._global_commands if found and not override: raise CommandAlreadyRegistered(name, None) if len(self._global_commands) + found > 100: - raise ValueError('maximum number of slash commands exceeded (100)') + raise ValueError('maximum number of global slash commands exceeded (100)') self._global_commands[name] = root @overload @@ -459,7 +509,8 @@ class CommandTree(Generic[ClientT]): *, name: str = MISSING, description: str = MISSING, - guild: Optional[Snowflake] = None, + guild: Optional[Snowflake] = MISSING, + guilds: List[Snowflake] = MISSING, ) -> Callable[[CommandCallback[Group, P, T]], Command[Group, P, T]]: """Creates an application command directly under this tree. @@ -475,6 +526,10 @@ class CommandTree(Generic[ClientT]): guild: Optional[:class:`~discord.abc.Snowflake`] The guild to add the command to. If not given then it becomes a global command instead. + guilds: List[:class:`~discord.abc.Snowflake`] + The list of guilds to add the command to. This cannot be mixed + with the ``guild`` parameter. If no guilds are given at all + then it becomes a global command instead. """ def decorator(func: CommandCallback[Group, P, T]) -> Command[Group, P, T]: @@ -495,13 +550,17 @@ class CommandTree(Generic[ClientT]): callback=func, parent=None, ) - self.add_command(command, guild=guild) + self.add_command(command, guild=guild, guilds=guilds) return command return decorator def context_menu( - self, *, name: str = MISSING, guild: Optional[Snowflake] = None + self, + *, + name: str = MISSING, + guild: Optional[Snowflake] = MISSING, + guilds: List[Snowflake] = MISSING, ) -> Callable[[ContextMenuCallback], ContextMenu]: """Creates a application command context menu from a regular function directly under this tree. @@ -531,6 +590,10 @@ class CommandTree(Generic[ClientT]): guild: Optional[:class:`~discord.abc.Snowflake`] The guild to add the command to. If not given then it becomes a global command instead. + guilds: List[:class:`~discord.abc.Snowflake`] + The list of guilds to add the command to. This cannot be mixed + with the ``guild`` parameter. If no guilds are given at all + then it becomes a global command instead. """ def decorator(func: ContextMenuCallback) -> ContextMenu: @@ -538,7 +601,7 @@ class CommandTree(Generic[ClientT]): raise TypeError('context menu function must be a coroutine function') context_menu = ContextMenu._from_decorator(func, name=name) - self.add_command(context_menu, guild=guild) + self.add_command(context_menu, guild=guild, guilds=guilds) return context_menu return decorator