Browse Source

[commands] Allow Cog and app_commands interopability

This changeset allows app commands defined inside Cog to work as
expected. Likewise, by deriving app_commands.Group and Cog you can
make the cog function as a top level command on Discord.
pull/7611/head
Rapptz 3 years ago
parent
commit
446bfa78b0
  1. 19
      discord/app_commands/commands.py
  2. 81
      discord/ext/commands/bot.py
  3. 78
      discord/ext/commands/cog.py

19
discord/app_commands/commands.py

@ -61,6 +61,11 @@ if TYPE_CHECKING:
from .namespace import Namespace
from .models import ChoiceT
# Generally, these two libraries are supposed to be separate from each other.
# However, for type hinting purposes it's unfortunately necessary for one to
# reference the other to prevent type checking errors in callbacks
from discord.ext.commands import Cog
__all__ = (
'Command',
'ContextMenu',
@ -79,7 +84,7 @@ else:
P = TypeVar('P')
T = TypeVar('T')
GroupT = TypeVar('GroupT', bound='Group')
GroupT = TypeVar('GroupT', bound='Union[Group, Cog]')
Coro = Coroutine[Any, Any, T]
Error = Union[
Callable[[GroupT, Interaction, AppCommandError], Coro[Any]],
@ -628,15 +633,14 @@ class Group:
"""
__discord_app_commands_group_children__: ClassVar[List[Union[Command, Group]]] = []
__discord_app_commands_skip_init_binding__: bool = False
__discord_app_commands_group_name__: str = MISSING
__discord_app_commands_group_description__: str = MISSING
def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None:
if not cls.__discord_app_commands_group_children__:
cls.__discord_app_commands_group_children__ = children = [
member
for member in cls.__dict__.values()
if isinstance(member, (Group, Command)) and member.parent is None
member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None
]
found = set()
@ -661,7 +665,6 @@ class Group:
else:
cls.__discord_app_commands_group_description__ = description
def __init__(
self,
*,
@ -683,10 +686,10 @@ class Group:
self._children: Dict[str, Union[Command, Group]] = {}
for child in self.__discord_app_commands_group_children__:
child = child._copy_with_binding(self)
child = child._copy_with_binding(self) if not cls.__discord_app_commands_skip_init_binding__ else child
child.parent = self
self._children[child.name] = child
if child._attr:
if child._attr and not cls.__discord_app_commands_skip_init_binding__:
setattr(self, child._attr, child)
if parent is not None and parent.parent is not None:
@ -695,7 +698,7 @@ class Group:
def __set_name__(self, owner: Type[Any], name: str) -> None:
self._attr = name
def _copy_with_binding(self, binding: Group) -> Group:
def _copy_with_binding(self, binding: Union[Group, Cog]) -> Group:
cls = self.__class__
copy = cls.__new__(cls)
copy.name = self.name

81
discord/ext/commands/bot.py

@ -36,6 +36,8 @@ import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
import discord
from discord import app_commands
from discord.app_commands.tree import _retrieve_guild_ids
from .core import GroupMixin
from .view import StringView
@ -50,7 +52,7 @@ if TYPE_CHECKING:
import importlib.machinery
from discord.message import Message
from discord.abc import User
from discord.abc import User, Snowflake
from ._types import (
Check,
CoroFunc,
@ -135,6 +137,8 @@ class BotBase(GroupMixin):
super().__init__(**options)
self.command_prefix = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {}
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
@ -529,11 +533,22 @@ class BotBase(GroupMixin):
# cogs
def add_cog(self, cog: Cog, /, *, override: bool = False) -> None:
def add_cog(
self,
cog: Cog,
/,
*,
override: bool = False,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> None:
"""Adds a "cog" to the bot.
A cog is a class that has its own event listeners and commands.
If the cog is a :class:`.app_commands.Group` then it is added to
the bot's :class:`~discord.app_commands.CommandTree` as well.
.. versionchanged:: 2.0
:exc:`.ClientException` is raised when a cog with the same name
@ -551,6 +566,19 @@ class BotBase(GroupMixin):
If a previously loaded cog with the same name should be ejected
instead of raising an error.
.. versionadded:: 2.0
guild: Optional[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guild where the cog group would be added to. If not given then
it becomes a global command instead.
.. versionadded:: 2.0
guilds: List[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guilds where the cog group would be added to. If not given then
it becomes a global command instead. Cannot be mixed with
``guild``.
.. versionadded:: 2.0
Raises
@ -572,7 +600,10 @@ class BotBase(GroupMixin):
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
self.remove_cog(cog_name)
self.remove_cog(cog_name, guild=guild, guilds=guilds)
if isinstance(cog, app_commands.Group):
self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds)
cog = cog._inject(self)
self.__cogs[cog_name] = cog
@ -600,7 +631,13 @@ class BotBase(GroupMixin):
"""
return self.__cogs.get(name)
def remove_cog(self, name: str, /) -> Optional[Cog]:
def remove_cog(
self,
name: str,
/,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
All registered commands and event listeners that the
@ -616,6 +653,19 @@ class BotBase(GroupMixin):
-----------
name: :class:`str`
The name of the cog to remove.
guild: Optional[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guild where the cog group would be removed from. If not given then
a global command is removed instead instead.
.. versionadded:: 2.0
guilds: List[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guilds where the cog group would be removed from. If not given then
a global command is removed instead instead. Cannot be mixed with
``guild``.
.. versionadded:: 2.0
Returns
-------
@ -630,6 +680,15 @@ class BotBase(GroupMixin):
help_command = self._help_command
if help_command and help_command.cog is cog:
help_command.cog = None
if isinstance(cog, app_commands.Group):
guild_ids = _retrieve_guild_ids(cog, guild, guilds)
if guild_ids is None:
self.__tree.remove_command(name)
else:
for guild_id in guild_ids:
self.__tree.remove_command(name, guild=discord.Object(guild_id))
cog._eject(self)
return cog
@ -894,6 +953,20 @@ class BotBase(GroupMixin):
else:
self._help_command = None
# application command interop
# As mentioned above, this is a mixin so the Self type hint fails here.
# However, since the only classes that can use this are subclasses of Client
# anyway, then this is sound.
@property
def tree(self) -> app_commands.CommandTree[Self]: # type: ignore
""":class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands
in this bot.
.. versionadded:: 2.0
"""
return self.__tree
# command processing
async def get_prefix(self, message: Message) -> Union[List[str], str]:

78
discord/ext/commands/cog.py

@ -24,14 +24,15 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import inspect
import discord.utils
import discord
from discord import app_commands
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
from typing import Any, Callable, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union, Type
from ._types import _BaseCommand
if TYPE_CHECKING:
from typing_extensions import Self
from typing_extensions import Self, TypeGuard
from .bot import BotBase
from .context import Context
@ -110,19 +111,33 @@ class CogMeta(type):
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_is_app_commands_group__: bool
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_name__'] = kwargs.get('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
attrs['__cog_is_app_commands_group__'] = is_parent = app_commands.Group in bases
description = kwargs.pop('description', None)
description = kwargs.get('description', None)
if description is None:
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
if is_parent:
attrs['__discord_app_commands_skip_init_binding__'] = True
# This is hacky, but it signals the Group not to process this info.
# It's overridden later.
attrs['__discord_app_commands_group_children__'] = True
else:
# Remove the extraneous keyword arguments we're using
kwargs.pop('name', None)
kwargs.pop('description', None)
commands = {}
cog_app_commands = {}
listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
@ -143,6 +158,8 @@ class CogMeta(type):
if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None:
cog_app_commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
@ -154,6 +171,13 @@ class CogMeta(type):
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_app_commands__ = list(cog_app_commands.values())
if is_parent:
# Prefill the app commands for the Group as well..
# The type checker doesn't like runtime attribute modification and this one's
# optional so it can't be cheesed.
new_cls.__discord_app_commands_group_children__ = cog_app_commands # type: ignore
listeners_as_list = []
for listener in listeners.values():
@ -189,10 +213,11 @@ class Cog(metaclass=CogMeta):
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command[Self, ..., Any]]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command[Self, ..., Any]]
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# For issue 426, we need to store a copy of the command objects
@ -219,6 +244,25 @@ class Cog(metaclass=CogMeta):
parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore
# Register the application commands
children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = []
for command in cls.__cog_app_commands__:
copy = command._copy_with_binding(self)
if cls.__cog_is_app_commands_group__:
# Type checker doesn't understand this type of narrowing.
# Not even with TypeGuard somehow.
copy.parent = self # type: ignore
children.append(copy)
if command._attr:
setattr(self, command._attr, copy)
self.__cog_app_commands__ = children
if cls.__cog_is_app_commands_group__:
# Dynamic attribute setting
self.__discord_app_commands_group_children__ = children # type: ignore
return self
def get_commands(self) -> List[Command[Self, ..., Any]]:
@ -452,6 +496,12 @@ class Cog(metaclass=CogMeta):
for name, method_name in self.__cog_listeners__:
bot.add_listener(getattr(self, method_name), name)
# Only do this if these are "top level" commands
if not cls.__cog_is_app_commands_group__:
for command in self.__cog_app_commands__:
# This is already atomic
bot.tree.add_command(command)
return self
def _eject(self, bot: BotBase) -> None:
@ -462,6 +512,16 @@ class Cog(metaclass=CogMeta):
if command.parent is None:
bot.remove_command(command.name)
if not cls.__cog_is_app_commands_group__:
for command in self.__cog_app_commands__:
try:
guild_ids = command.__discord_app_commands_default_guilds__
except AttributeError:
bot.tree.remove_command(command.name)
else:
for guild_id in guild_ids:
bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id))
for name, method_name in self.__cog_listeners__:
bot.remove_listener(getattr(self, method_name), name)

Loading…
Cancel
Save