From 0a8065606aa4a02307a75ab8c58556e5018bb989 Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Wed, 30 Mar 2022 18:12:39 -0500 Subject: [PATCH] Update parent reference of Group children --- discord/app_commands/commands.py | 66 +++++-- discord/ext/commands/cog.py | 10 +- tests/test_app_commands_group.py | 317 +++++++++++++++++++++++++++++++ 3 files changed, 375 insertions(+), 18 deletions(-) create mode 100644 tests/test_app_commands_group.py diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 4b2133336..3a1b4467a 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -34,6 +34,7 @@ from typing import ( Generator, Generic, List, + MutableMapping, Optional, Set, TYPE_CHECKING, @@ -87,7 +88,7 @@ else: P = TypeVar('P') T = TypeVar('T') -GroupT = TypeVar('GroupT', bound='Union[Group, Cog]') +GroupT = TypeVar('GroupT', bound='Binding') Coro = Coroutine[Any, Any, T] UnboundError = Callable[['Interaction', AppCommandError], Coro[Any]] Error = Union[ @@ -95,6 +96,7 @@ Error = Union[ UnboundError, ] Check = Callable[['Interaction'], Union[bool, Coro[bool]]] +Binding = Union['Group', 'Cog'] if TYPE_CHECKING: @@ -442,7 +444,16 @@ class Command(Generic[GroupT, P, T]): """:ref:`coroutine `: The coroutine that is executed when the command is called.""" return self._callback - def _copy_with_binding(self, binding: GroupT) -> Command: + def _copy_with( + self, + *, + parent: Optional[Group], + binding: GroupT, + bindings: MutableMapping[GroupT, GroupT] = MISSING, + set_on_binding: bool = True, + ) -> Command: + bindings = {} if bindings is MISSING else bindings + cls = self.__class__ copy = cls.__new__(cls) copy.name = self.name @@ -451,11 +462,15 @@ class Command(Generic[GroupT, P, T]): copy.description = self.description copy._attr = self._attr copy._callback = self._callback - copy.parent = self.parent copy.on_error = self.on_error copy._params = self._params.copy() copy.module = self.module - copy.binding = binding + copy.parent = parent + copy.binding = bindings.get(self.binding) if self.binding is not None else binding + + if copy._attr and set_on_binding: + setattr(copy.binding, copy._attr, copy) + return copy def to_dict(self) -> Dict[str, Any]: @@ -969,12 +984,19 @@ class Group: self._children: Dict[str, Union[Command, Group]] = {} + bindings: Dict[Group, Group] = {} + for child in self.__discord_app_commands_group_children__: - child.parent = self - child = child._copy_with_binding(self) if not cls.__discord_app_commands_skip_init_binding__ else child - self._children[child.name] = child - if child._attr and not cls.__discord_app_commands_skip_init_binding__: - setattr(self, child._attr, child) + # commands and groups created directly in this class (no parent) + copy = ( + child._copy_with(parent=self, binding=self, bindings=bindings, set_on_binding=False) + if not cls.__discord_app_commands_skip_init_binding__ + else child + ) + + self._children[copy.name] = copy + if copy._attr and not cls.__discord_app_commands_skip_init_binding__: + setattr(self, copy._attr, copy) if parent is not None and parent.parent is not None: raise ValueError('groups can only be nested at most one level') @@ -983,16 +1005,36 @@ class Group: self._attr = name self.module = owner.__module__ - def _copy_with_binding(self, binding: Union[Group, Cog]) -> Group: + def _copy_with( + self, + *, + parent: Optional[Group], + binding: Binding, + bindings: MutableMapping[Group, Group] = MISSING, + set_on_binding: bool = True, + ) -> Group: + bindings = {} if bindings is MISSING else bindings + cls = self.__class__ copy = cls.__new__(cls) copy.name = self.name copy._guild_ids = self._guild_ids copy.description = self.description - copy.parent = self.parent + copy.parent = parent copy.module = self.module copy._attr = self._attr - copy._children = {child.name: child._copy_with_binding(binding) for child in self._children.values()} + copy._children = {} + + bindings[self] = copy + + for child in self._children.values(): + child_copy = child._copy_with(parent=copy, binding=binding, bindings=bindings) + child_copy.parent = copy + copy._children[child_copy.name] = child_copy + + if copy._attr and set_on_binding: + setattr(parent or binding, copy._attr, copy) + return copy def to_dict(self) -> Dict[str, Any]: diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 1cd4f8895..6e2ec38a8 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -250,16 +250,14 @@ class Cog(metaclass=CogMeta): # Register the application commands children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = [] for command in cls.__cog_app_commands__: - if cls.__cog_is_app_commands_group__: + copy = command._copy_with( # Type checker doesn't understand this type of narrowing. # Not even with TypeGuard somehow. - command.parent = self # type: ignore - - copy = command._copy_with_binding(self) + parent=self if cls.__cog_is_app_commands_group__ else None, # type: ignore + binding=self, + ) children.append(copy) - if command._attr: - setattr(self, command._attr, copy) self.__cog_app_commands__ = children if cls.__cog_is_app_commands_group__: diff --git a/tests/test_app_commands_group.py b/tests/test_app_commands_group.py new file mode 100644 index 000000000..11f482b3b --- /dev/null +++ b/tests/test_app_commands_group.py @@ -0,0 +1,317 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from discord import app_commands +import discord +from discord.ext import commands + + +def test_group_with_commands(): + my_group = app_commands.Group(name='mygroup', description='My group') + + @my_group.command() + async def my_command(interaction: discord.Interaction) -> None: + ... + + assert my_command.binding is None + assert my_command.parent is my_group + assert my_group.commands[0] is my_command + + +def test_group_subclass_with_commands(): + class MyGroup(app_commands.Group, name='mygroup'): + @app_commands.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + my_group = MyGroup() + assert MyGroup.__discord_app_commands_group_children__[0].parent is not my_group + assert my_group.my_command is not MyGroup.my_command + assert my_group.my_command.parent is my_group + + +def test_group_subclass_with_group(): + class MyGroup(app_commands.Group, name='mygroup'): + sub_group = app_commands.Group(name='mysubgroup', description='My sub-group') + + @sub_group.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + my_group = MyGroup() + assert MyGroup.__discord_app_commands_group_children__[0].parent is not my_group + assert MyGroup.sub_group.parent is None + assert MyGroup.my_command.parent is MyGroup.sub_group + assert my_group.sub_group is not MyGroup.sub_group + assert my_group.my_command is not MyGroup.my_command + assert my_group.sub_group.parent is my_group + assert my_group.my_command.parent is my_group.sub_group + assert my_group.my_command.binding is my_group + + +def test_group_subclass_with_group_subclass(): + class MySubGroup(app_commands.Group, name='mysubgroup'): + @app_commands.command() + async def my_sub_group_command(self, interaction: discord.Interaction) -> None: + ... + + class MyGroup(app_commands.Group, name='mygroup'): + sub_group = MySubGroup() + + @app_commands.command() + async def my_group_command(self, interaction: discord.Interaction) -> None: + ... + + my_group = MyGroup() + assert MyGroup.__discord_app_commands_group_children__[0].parent is not my_group + assert MySubGroup.__discord_app_commands_group_children__[0].parent is not my_group.sub_group + assert my_group.sub_group is not MyGroup.sub_group + assert my_group.my_group_command is not MyGroup.my_group_command + assert my_group.sub_group.my_sub_group_command is not MySubGroup.my_sub_group_command + assert my_group.sub_group.parent is my_group + assert my_group.my_group_command.parent is my_group + assert my_group.my_group_command.binding is my_group + assert my_group.sub_group.my_sub_group_command.parent is my_group.sub_group + print(my_group.sub_group.my_sub_group_command.binding) + print(MyGroup.sub_group) + assert my_group.sub_group.my_sub_group_command.binding is my_group.sub_group + + +def test_cog_with_commands(): + class MyCog(commands.Cog): + @app_commands.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert cog.my_command.parent is None + assert cog.my_command.binding is cog + + +def test_cog_with_group_with_commands(): + class MyCog(commands.Cog): + my_group = app_commands.Group(name='mygroup', description='My group') + + @my_group.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert cog.my_group is not MyCog.my_group + assert cog.my_command is not MyCog.my_command + assert cog.my_group.parent is None + assert cog.my_command.parent is cog.my_group + assert cog.my_command.binding is cog + + +def test_cog_with_group_subclass_with_commands(): + class MyGroup(app_commands.Group, name='mygroup'): + @app_commands.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + class MyCog(commands.Cog): + my_group = MyGroup() + + @my_group.command() + async def my_cog_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert MyGroup.__discord_app_commands_group_children__[0].parent is not cog.my_group + assert cog.my_group is not MyCog.my_group + assert cog.my_group.my_command is not MyGroup.my_command + assert cog.my_cog_command is not MyCog.my_cog_command + assert not hasattr(cog.my_group, 'my_cog_command') + assert cog.my_group.parent is None + assert cog.my_group.my_command.parent is cog.my_group + assert cog.my_group.my_command.binding is cog.my_group + assert cog.my_cog_command.parent is cog.my_group + assert cog.my_cog_command.binding is cog + + +def test_cog_with_group_subclass_with_group(): + class MyGroup(app_commands.Group, name='mygroup'): + sub_group = app_commands.Group(name='mysubgroup', description='My sub-group') + + @sub_group.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + class MyCog(commands.Cog): + my_group = MyGroup() + + @my_group.command() + async def my_cog_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert MyGroup.__discord_app_commands_group_children__[0].parent is not cog.my_group + assert cog.my_group is not MyCog.my_group + assert cog.my_group.sub_group is not MyGroup.sub_group + assert cog.my_group.my_command is not MyGroup.my_command + assert cog.my_cog_command is not MyCog.my_cog_command + assert not hasattr(cog.my_group, 'my_cog_command') + assert cog.my_group.parent is None + assert cog.my_group.sub_group.parent is cog.my_group + assert cog.my_group.my_command.parent is cog.my_group.sub_group + assert cog.my_group.my_command.binding is cog.my_group + assert cog.my_cog_command.parent is cog.my_group + assert cog.my_cog_command.binding is cog + + +def test_cog_with_group_subclass_with_group_subclass(): + class MySubGroup(app_commands.Group, name='mysubgroup'): + @app_commands.command() + async def my_sub_group_command(self, interaction: discord.Interaction) -> None: + ... + + class MyGroup(app_commands.Group, name='mygroup'): + sub_group = MySubGroup() + + @app_commands.command() + async def my_group_command(self, interaction: discord.Interaction) -> None: + ... + + class MyCog(commands.Cog): + my_group = MyGroup() + + @my_group.command() + async def my_cog_command(self, interaction: discord.Interaction) -> None: + ... + + @my_group.sub_group.command() + async def my_sub_group_cog_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert MyGroup.__discord_app_commands_group_children__[0].parent is not cog.my_group + assert MySubGroup.__discord_app_commands_group_children__[0].parent is not cog.my_group.sub_group + assert cog.my_group is not MyCog.my_group + assert cog.my_group.my_group_command is not MyCog.my_group.my_group_command + assert cog.my_group.sub_group is not MyGroup.sub_group + assert cog.my_cog_command is not MyCog.my_cog_command + assert not hasattr(cog.my_group, 'my_cog_command') + assert cog.my_group.sub_group.my_sub_group_command is not MyGroup.sub_group.my_sub_group_command + assert cog.my_group.sub_group.my_sub_group_command is not MySubGroup.my_sub_group_command + assert cog.my_group.sub_group.parent is cog.my_group + assert cog.my_group.my_group_command.parent is cog.my_group + assert cog.my_group.my_group_command.binding is cog.my_group + assert cog.my_group.sub_group.my_sub_group_command.parent is cog.my_group.sub_group + assert cog.my_group.sub_group.my_sub_group_command.binding is cog.my_group.sub_group + assert cog.my_cog_command.parent is cog.my_group + assert cog.my_cog_command.binding is cog + assert cog.my_sub_group_cog_command.parent is cog.my_group.sub_group + assert cog.my_sub_group_cog_command.binding is cog + + +def test_cog_group_with_commands(): + class MyCog(commands.Cog, app_commands.Group): + @app_commands.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert MyCog.__discord_app_commands_group_children__[0].parent is not cog + assert cog.my_command is not MyCog.my_command + assert cog.parent is None + assert cog.my_command.parent is cog + + +def test_cog_group_with_group(): + class MyCog(commands.Cog, app_commands.Group): + sub_group = app_commands.Group(name='mysubgroup', description='My sub-group') + + @sub_group.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert MyCog.__discord_app_commands_group_children__[0].parent is not cog + assert cog.sub_group is not MyCog.sub_group + assert cog.my_command is not MyCog.my_command + assert cog.parent is None + assert cog.sub_group.parent is cog + assert cog.my_command.parent is cog.sub_group + + +def test_cog_group_with_subclass_group(): + class MyGroup(app_commands.Group, name='mygroup'): + @app_commands.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + class MyCog(commands.Cog, app_commands.Group): + sub_group = MyGroup() + + @sub_group.command() + async def my_cog_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert MyCog.__discord_app_commands_group_children__[0].parent is not cog + assert MyGroup.__discord_app_commands_group_children__[0].parent is not cog.sub_group + assert cog.sub_group is not MyCog.sub_group + assert cog.sub_group.my_command is not MyGroup.my_command + assert cog.my_cog_command is not MyCog.my_cog_command + assert not hasattr(cog.sub_group, 'my_cog_command') + assert cog.parent is None + assert cog.sub_group.parent is cog + assert cog.sub_group.my_command.parent is cog.sub_group + assert cog.my_cog_command.parent is cog.sub_group + assert cog.my_cog_command.binding is cog + + +def test_cog_group_with_subclassed_subclass_group(): + class MyGroup(app_commands.Group): + @app_commands.command() + async def my_command(self, interaction: discord.Interaction) -> None: + ... + + class MySubclassedGroup(MyGroup, name='mygroup'): + ... + + class MyCog(commands.Cog, app_commands.Group): + sub_group = MySubclassedGroup() + + @sub_group.command() + async def my_cog_command(self, interaction: discord.Interaction) -> None: + ... + + cog = MyCog() + assert MyCog.__discord_app_commands_group_children__[0].parent is not cog + assert MyGroup.__discord_app_commands_group_children__[0].parent is not cog.sub_group + assert MySubclassedGroup.__discord_app_commands_group_children__[0].parent is not cog.sub_group + assert cog.sub_group is not MyCog.sub_group + assert cog.sub_group.my_command is not MyGroup.my_command + assert cog.sub_group.my_command is not MySubclassedGroup.my_command + assert cog.my_cog_command is not MyCog.my_cog_command + assert not hasattr(cog.sub_group, 'my_cog_command') + assert cog.parent is None + assert cog.sub_group.parent is cog + assert cog.sub_group.my_command.parent is cog.sub_group + assert cog.my_cog_command.parent is cog.sub_group + assert cog.my_cog_command.binding is cog