Browse Source

Support Group with app_commands.guilds decorator

pull/7611/head
Rapptz 3 years ago
parent
commit
303d33bb08
  1. 7
      discord/app_commands/commands.py
  2. 6
      discord/app_commands/tree.py

7
discord/app_commands/commands.py

@ -313,6 +313,8 @@ class Command(Generic[GroupT, P, T]):
The parent application command. ``None`` if there isn't one.
"""
__discord_app_commands_default_guilds__: List[int]
def __init__(
self,
*,
@ -620,6 +622,7 @@ class Group:
"""
__discord_app_commands_group_children__: ClassVar[Dict[str, Union[Command, Group]]] = {}
__discord_app_commands_default_guilds__: List[int]
__discord_app_commands_group_name__: str = MISSING
__discord_app_commands_group_description__: str = MISSING
@ -1109,8 +1112,8 @@ def guilds(*guild_ids: Union[Snowflake, int]) -> Callable[[T], T]:
defaults: List[int] = [g if isinstance(g, int) else g.id for g in guild_ids]
def decorator(inner: T) -> T:
if isinstance(inner, Command):
inner._callback.__discord_app_commands_default_guilds__ = defaults
if isinstance(inner, (Command, Group)):
inner.__discord_app_commands_default_guilds__ = defaults
else:
# Runtime attribute assignment
inner.__discord_app_commands_default_guilds__ = defaults # type: ignore

6
discord/app_commands/tree.py

@ -55,7 +55,7 @@ ClientT = TypeVar('ClientT', bound='Client')
def _retrieve_guild_ids(
callback: Any, guild: Optional[Snowflake] = MISSING, guilds: List[Snowflake] = MISSING
command: Any, guild: Optional[Snowflake] = MISSING, guilds: List[Snowflake] = MISSING
) -> Optional[Set[int]]:
if guild is not MISSING and guilds is not MISSING:
raise TypeError('cannot mix guild and guilds keyword arguments')
@ -65,7 +65,7 @@ def _retrieve_guild_ids(
# If no arguments are given then it should default to the ones
# given to the guilds(...) decorator or None for global.
if guild is MISSING:
return getattr(callback, '__discord_app_commands_default_guilds__', None)
return getattr(command, '__discord_app_commands_default_guilds__', None)
# guilds=[] is the same as global
if len(guilds) == 0:
@ -185,7 +185,7 @@ class CommandTree(Generic[ClientT]):
This is currently 100 for slash commands and 5 for context menu commands.
"""
guild_ids = _retrieve_guild_ids(getattr(command, '_callback', None), guild, guilds)
guild_ids = _retrieve_guild_ids(command, guild, guilds)
if isinstance(command, ContextMenu):
type = command.type.value
name = command.name

Loading…
Cancel
Save