From 0ef369c0fafb1de871f07d86122a119c1f6548fd Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 12 Mar 2022 09:24:26 -0500 Subject: [PATCH] [commands] Automatically unload top level app commands in extensions --- discord/app_commands/commands.py | 18 ++++++++++++++++++ discord/app_commands/tree.py | 28 +++++++++++++++++++++++++++- discord/ext/commands/bot.py | 10 ++++------ discord/ext/commands/cog.py | 2 ++ discord/utils.py | 4 ++++ 5 files changed, 55 insertions(+), 7 deletions(-) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 32b63889d..dbb590f1a 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -365,6 +365,7 @@ class Command(Generic[GroupT, P, T]): self.parent: Optional[Group] = parent self.binding: Optional[GroupT] = None self.on_error: Optional[Error[GroupT]] = None + self.module: Optional[str] = callback.__module__ # Unwrap __self__ for bound methods try: @@ -626,6 +627,7 @@ class ContextMenu: raise ValueError(f'context menu callback implies a type of {actual_type} but {type} was passed.') self._param_name = param self._annotation = annotation + self.module: Optional[str] = callback.__module__ @property def callback(self) -> ContextMenuCallback: @@ -642,6 +644,7 @@ class ContextMenu: self.type = type self._param_name = param self._annotation = annotation + self.module = callback.__module__ return self def to_dict(self) -> Dict[str, Any]: @@ -683,6 +686,7 @@ class Group: __discord_app_commands_skip_init_binding__: bool = False __discord_app_commands_group_name__: str = MISSING __discord_app_commands_group_description__: str = MISSING + __discord_app_commands_has_module__: bool = False def __init_subclass__(cls, *, name: str = MISSING, description: str = MISSING) -> None: if not cls.__discord_app_commands_group_children__: @@ -712,6 +716,9 @@ class Group: else: cls.__discord_app_commands_group_description__ = description + if cls.__module__ != __name__: + cls.__discord_app_commands_has_module__ = True + def __init__( self, *, @@ -730,6 +737,16 @@ class Group: raise TypeError('groups must have a description') self.parent: Optional[Group] = parent + self.module: Optional[str] + if cls.__discord_app_commands_has_module__: + self.module = cls.__module__ + else: + try: + # This is pretty hacky + # It allows the module to be fetched if someone just constructs a bare Group object though. + self.module = inspect.currentframe().f_back.f_globals['__name__'] # type: ignore + except (AttributeError, IndexError): + self.module = None self._children: Dict[str, Union[Command, Group]] = {} @@ -745,6 +762,7 @@ class Group: def __set_name__(self, owner: Type[Any], name: str) -> None: self._attr = name + self.module = owner.__module__ def _copy_with_binding(self, binding: Union[Group, Cog]) -> Group: cls = self.__class__ diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index 0af423c83..5695406a1 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -40,7 +40,7 @@ from .errors import ( ) from ..errors import ClientException from ..enums import AppCommandType, InteractionType -from ..utils import MISSING, _get_as_snowflake +from ..utils import MISSING, _get_as_snowflake, _is_submodule if TYPE_CHECKING: from ..types.interactions import ApplicationCommandInteractionData, ApplicationCommandInteractionDataOption @@ -489,6 +489,32 @@ class CommandTree(Generic[ClientT]): base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id) return base + def _remove_with_module(self, name: str) -> None: + remove: List[Any] = [] + for key, cmd in self._context_menus.items(): + if cmd.module is not None and _is_submodule(name, cmd.module): + remove.append(key) + + for key in remove: + del self._context_menus[key] + + remove = [] + for key, cmd in self._global_commands.items(): + if cmd.module is not None and _is_submodule(name, cmd.module): + remove.append(key) + + for key in remove: + del self._global_commands[key] + + for mapping in self._guild_commands.values(): + remove = [] + for key, cmd in mapping.items(): + if cmd.module is not None and _is_submodule(name, cmd.module): + remove.append(key) + + for key in remove: + del mapping[key] + async def on_error( self, interaction: Interaction, diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index b8b9b7cc8..f13643675 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -38,6 +38,7 @@ from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, import discord from discord import app_commands from discord.app_commands.tree import _retrieve_guild_ids +from discord.utils import MISSING, _is_submodule from .core import GroupMixin from .view import StringView @@ -65,8 +66,6 @@ __all__ = ( 'AutoShardedBot', ) -MISSING: Any = discord.utils.MISSING - T = TypeVar('T') CFT = TypeVar('CFT', bound='CoroFunc') CXT = TypeVar('CXT', bound='Context') @@ -120,10 +119,6 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M return inner -def _is_submodule(parent: str, child: str) -> bool: - return parent == child or child.startswith(parent + ".") - - class _DefaultRepr: def __repr__(self): return '' @@ -724,6 +719,9 @@ class BotBase(GroupMixin): for index in reversed(remove): del event_list[index] + # remove all relevant application commands from the tree + self.__tree._remove_with_module(name) + def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: try: func = getattr(lib, 'teardown') diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 3af426005..1ffbbe16d 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -263,6 +263,8 @@ class Cog(metaclass=CogMeta): if cls.__cog_is_app_commands_group__: # Dynamic attribute setting self.__discord_app_commands_group_children__ = children # type: ignore + # Enforce this to work even if someone forgets __init__ + self.module = cls.__module__ # type: ignore return self diff --git a/discord/utils.py b/discord/utils.py index 49c458dd6..6c517a380 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -572,6 +572,10 @@ def _bytes_to_base64_data(data: bytes) -> str: return fmt.format(mime=mime, data=b64) +def _is_submodule(parent: str, child: str) -> bool: + return parent == child or child.startswith(parent + '.') + + if HAS_ORJSON: def _to_json(obj: Any) -> str: # type: ignore