From 446bfa78b03cff8a837f510c4a7d8255c2d76b7d Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 9 Mar 2022 19:48:51 -0500 Subject: [PATCH] [commands] Allow Cog and app_commands interopability This changeset allows app commands defined inside Cog to work as expected. Likewise, by deriving app_commands.Group and Cog you can make the cog function as a top level command on Discord. --- discord/app_commands/commands.py | 19 ++++---- discord/ext/commands/bot.py | 81 ++++++++++++++++++++++++++++++-- discord/ext/commands/cog.py | 78 ++++++++++++++++++++++++++---- 3 files changed, 157 insertions(+), 21 deletions(-) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 5311fb147..1b0542a0e 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -61,6 +61,11 @@ if TYPE_CHECKING: from .namespace import Namespace from .models import ChoiceT + # Generally, these two libraries are supposed to be separate from each other. + # However, for type hinting purposes it's unfortunately necessary for one to + # reference the other to prevent type checking errors in callbacks + from discord.ext.commands import Cog + __all__ = ( 'Command', 'ContextMenu', @@ -79,7 +84,7 @@ else: P = TypeVar('P') T = TypeVar('T') -GroupT = TypeVar('GroupT', bound='Group') +GroupT = TypeVar('GroupT', bound='Union[Group, Cog]') Coro = Coroutine[Any, Any, T] Error = Union[ Callable[[GroupT, Interaction, AppCommandError], Coro[Any]], @@ -628,15 +633,14 @@ class Group: """ __discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = [] + __discord_app_commands_skip_init_binding__: bool = False __discord_app_commands_group_name__: str = MISSING __discord_app_commands_group_description__: str = MISSING def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None: if not cls.__discord_app_commands_group_children__: cls.__discord_app_commands_group_children__ = children = [ - member - for member in cls.__dict__.values() - if isinstance(member, (Group, Command)) and member.parent is None + member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None ] found = set() @@ -661,7 +665,6 @@ class Group: else: cls.__discord_app_commands_group_description__ = description - def __init__( self, *, @@ -683,10 +686,10 @@ class Group: self._children: Dict[str, Union[Command, Group]] = {} for child in self.__discord_app_commands_group_children__: - child = child._copy_with_binding(self) + child = child._copy_with_binding(self) if not cls.__discord_app_commands_skip_init_binding__ else child child.parent = self self._children[child.name] = child - if child._attr: + if child._attr and not cls.__discord_app_commands_skip_init_binding__: setattr(self, child._attr, child) if parent is not None and parent.parent is not None: @@ -695,7 +698,7 @@ class Group: def __set_name__(self, owner: Type[Any], name: str) -> None: self._attr = name - def _copy_with_binding(self, binding: Group) -> Group: + def _copy_with_binding(self, binding: Union[Group, Cog]) -> Group: cls = self.__class__ copy = cls.__new__(cls) copy.name = self.name diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 3746ce118..231583cc9 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -36,6 +36,8 @@ import types from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload import discord +from discord import app_commands +from discord.app_commands.tree import _retrieve_guild_ids from .core import GroupMixin from .view import StringView @@ -50,7 +52,7 @@ if TYPE_CHECKING: import importlib.machinery from discord.message import Message - from discord.abc import User + from discord.abc import User, Snowflake from ._types import ( Check, CoroFunc, @@ -135,6 +137,8 @@ class BotBase(GroupMixin): super().__init__(**options) self.command_prefix = command_prefix self.extra_events: Dict[str, List[CoroFunc]] = {} + # Self doesn't have the ClientT bound, but since this is a mixin it technically does + self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} self._checks: List[Check] = [] @@ -529,11 +533,22 @@ class BotBase(GroupMixin): # cogs - def add_cog(self, cog: Cog, /, *, override: bool = False) -> None: + def add_cog( + self, + cog: Cog, + /, + *, + override: bool = False, + guild: Optional[Snowflake] = MISSING, + guilds: List[Snowflake] = MISSING, + ) -> None: """Adds a "cog" to the bot. A cog is a class that has its own event listeners and commands. + If the cog is a :class:`.app_commands.Group` then it is added to + the bot's :class:`~discord.app_commands.CommandTree` as well. + .. versionchanged:: 2.0 :exc:`.ClientException` is raised when a cog with the same name @@ -551,6 +566,19 @@ class BotBase(GroupMixin): If a previously loaded cog with the same name should be ejected instead of raising an error. + .. versionadded:: 2.0 + guild: Optional[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guild where the cog group would be added to. If not given then + it becomes a global command instead. + + .. versionadded:: 2.0 + guilds: List[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guilds where the cog group would be added to. If not given then + it becomes a global command instead. Cannot be mixed with + ``guild``. + .. versionadded:: 2.0 Raises @@ -572,7 +600,10 @@ class BotBase(GroupMixin): if existing is not None: if not override: raise discord.ClientException(f'Cog named {cog_name!r} already loaded') - self.remove_cog(cog_name) + self.remove_cog(cog_name, guild=guild, guilds=guilds) + + if isinstance(cog, app_commands.Group): + self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds) cog = cog._inject(self) self.__cogs[cog_name] = cog @@ -600,7 +631,13 @@ class BotBase(GroupMixin): """ return self.__cogs.get(name) - def remove_cog(self, name: str, /) -> Optional[Cog]: + def remove_cog( + self, + name: str, + /, + guild: Optional[Snowflake] = MISSING, + guilds: List[Snowflake] = MISSING, + ) -> Optional[Cog]: """Removes a cog from the bot and returns it. All registered commands and event listeners that the @@ -616,6 +653,19 @@ class BotBase(GroupMixin): ----------- name: :class:`str` The name of the cog to remove. + guild: Optional[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guild where the cog group would be removed from. If not given then + a global command is removed instead instead. + + .. versionadded:: 2.0 + guilds: List[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guilds where the cog group would be removed from. If not given then + a global command is removed instead instead. Cannot be mixed with + ``guild``. + + .. versionadded:: 2.0 Returns ------- @@ -630,6 +680,15 @@ class BotBase(GroupMixin): help_command = self._help_command if help_command and help_command.cog is cog: help_command.cog = None + + if isinstance(cog, app_commands.Group): + guild_ids = _retrieve_guild_ids(cog, guild, guilds) + if guild_ids is None: + self.__tree.remove_command(name) + else: + for guild_id in guild_ids: + self.__tree.remove_command(name, guild=discord.Object(guild_id)) + cog._eject(self) return cog @@ -894,6 +953,20 @@ class BotBase(GroupMixin): else: self._help_command = None + # application command interop + + # As mentioned above, this is a mixin so the Self type hint fails here. + # However, since the only classes that can use this are subclasses of Client + # anyway, then this is sound. + @property + def tree(self) -> app_commands.CommandTree[Self]: # type: ignore + """:class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands + in this bot. + + .. versionadded:: 2.0 + """ + return self.__tree + # command processing async def get_prefix(self, message: Message) -> Union[List[str], str]: diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index fa8519d35..8d88bf4d6 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -24,14 +24,15 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import inspect -import discord.utils +import discord +from discord import app_commands -from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type +from typing import Any, Callable, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, Type from ._types import _BaseCommand if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeGuard from .bot import BotBase from .context import Context @@ -110,19 +111,33 @@ class CogMeta(type): __cog_name__: str __cog_settings__: Dict[str, Any] __cog_commands__: List[Command] + __cog_is_app_commands_group__: bool + __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]] __cog_listeners__: List[Tuple[str, str]] def __new__(cls, *args: Any, **kwargs: Any) -> Self: name, bases, attrs = args - attrs['__cog_name__'] = kwargs.pop('name', name) + attrs['__cog_name__'] = kwargs.get('name', name) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) + attrs['__cog_is_app_commands_group__'] = is_parent = app_commands.Group in bases - description = kwargs.pop('description', None) + description = kwargs.get('description', None) if description is None: description = inspect.cleandoc(attrs.get('__doc__', '')) attrs['__cog_description__'] = description + if is_parent: + attrs['__discord_app_commands_skip_init_binding__'] = True + # This is hacky, but it signals the Group not to process this info. + # It's overridden later. + attrs['__discord_app_commands_group_children__'] = True + else: + # Remove the extraneous keyword arguments we're using + kwargs.pop('name', None) + kwargs.pop('description', None) + commands = {} + cog_app_commands = {} listeners = {} no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' @@ -143,6 +158,8 @@ class CogMeta(type): if elem.startswith(('cog_', 'bot_')): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value + elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None: + cog_app_commands[elem] = value elif inspect.iscoroutinefunction(value): try: getattr(value, '__cog_listener__') @@ -154,6 +171,13 @@ class CogMeta(type): listeners[elem] = value new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ + new_cls.__cog_app_commands__ = list(cog_app_commands.values()) + + if is_parent: + # Prefill the app commands for the Group as well.. + # The type checker doesn't like runtime attribute modification and this one's + # optional so it can't be cheesed. + new_cls.__discord_app_commands_group_children__ = cog_app_commands # type: ignore listeners_as_list = [] for listener in listeners.values(): @@ -189,10 +213,11 @@ class Cog(metaclass=CogMeta): are equally valid here. """ - __cog_name__: ClassVar[str] - __cog_settings__: ClassVar[Dict[str, Any]] - __cog_commands__: ClassVar[List[Command[Self, ..., Any]]] - __cog_listeners__: ClassVar[List[Tuple[str, str]]] + __cog_name__: str + __cog_settings__: Dict[str, Any] + __cog_commands__: List[Command[Self, ..., Any]] + __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] + __cog_listeners__: List[Tuple[str, str]] def __new__(cls, *args: Any, **kwargs: Any) -> Self: # For issue 426, we need to store a copy of the command objects @@ -219,6 +244,25 @@ class Cog(metaclass=CogMeta): parent.remove_command(command.name) # type: ignore parent.add_command(command) # type: ignore + # Register the application commands + children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = [] + for command in cls.__cog_app_commands__: + copy = command._copy_with_binding(self) + + if cls.__cog_is_app_commands_group__: + # Type checker doesn't understand this type of narrowing. + # Not even with TypeGuard somehow. + copy.parent = self # type: ignore + + children.append(copy) + if command._attr: + setattr(self, command._attr, copy) + + self.__cog_app_commands__ = children + if cls.__cog_is_app_commands_group__: + # Dynamic attribute setting + self.__discord_app_commands_group_children__ = children # type: ignore + return self def get_commands(self) -> List[Command[Self, ..., Any]]: @@ -452,6 +496,12 @@ class Cog(metaclass=CogMeta): for name, method_name in self.__cog_listeners__: bot.add_listener(getattr(self, method_name), name) + # Only do this if these are "top level" commands + if not cls.__cog_is_app_commands_group__: + for command in self.__cog_app_commands__: + # This is already atomic + bot.tree.add_command(command) + return self def _eject(self, bot: BotBase) -> None: @@ -462,6 +512,16 @@ class Cog(metaclass=CogMeta): if command.parent is None: bot.remove_command(command.name) + if not cls.__cog_is_app_commands_group__: + for command in self.__cog_app_commands__: + try: + guild_ids = command.__discord_app_commands_default_guilds__ + except AttributeError: + bot.tree.remove_command(command.name) + else: + for guild_id in guild_ids: + bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id)) + for name, method_name in self.__cog_listeners__: bot.remove_listener(getattr(self, method_name), name)