From 06c43d6772a8e812d2f746757166d63a1d25d002 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 22 May 2022 19:31:28 -0400 Subject: [PATCH] [commands] Add support for NSFW commands for hybrid commands --- discord/ext/commands/cog.py | 8 +++++++ discord/ext/commands/core.py | 42 +++++++++++++++++++++++++++++++--- discord/ext/commands/hybrid.py | 5 +++- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 1fdf40111..6f14d7ddb 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -119,6 +119,11 @@ class CogMeta(type): The group description of a cog. This is only applicable for :class:`GroupCog` instances. By default, it's the same value as :attr:`description`. + .. versionadded:: 2.0 + group_nsfw: :class:`bool` + Whether the application command group is NSFW. This is only applicable for :class:`GroupCog` instances. + By default, it's ``False``. + .. versionadded:: 2.0 """ @@ -126,6 +131,7 @@ class CogMeta(type): __cog_description__: str __cog_group_name__: str __cog_group_description__: str + __cog_group_nsfw__: bool __cog_settings__: Dict[str, Any] __cog_commands__: List[Command[Any, ..., Any]] __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]] @@ -154,6 +160,7 @@ class CogMeta(type): attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) attrs['__cog_name__'] = cog_name attrs['__cog_group_name__'] = group_name + attrs['__cog_group_nsfw__'] = kwargs.pop('group_nsfw', False) description = kwargs.pop('description', None) if description is None: @@ -268,6 +275,7 @@ class Cog(metaclass=CogMeta): group = app_commands.Group( name=cls.__cog_group_name__, description=cls.__cog_group_description__, + nsfw=cls.__cog_group_nsfw__, parent=None, guild_ids=getattr(cls, '__discord_app_commands_default_guilds__', None), guild_only=getattr(cls, '__discord_app_commands_guild_only__', False), diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 0a2be42a1..4a0b2a96a 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -2282,7 +2282,9 @@ def guild_only() -> Check[Any]: that is inherited from :exc:`.CheckFailure`. If used on hybrid commands, this will be equivalent to the - :func:`discord.app_commands.guild_only` decorator. + :func:`discord.app_commands.guild_only` decorator. In an unsupported + context, such as a subcommand, this will still fallback to applying the + check. """ # Due to implementation quirks, this check has to be re-implemented completely @@ -2346,13 +2348,21 @@ def is_nsfw() -> Check[Any]: This check raises a special exception, :exc:`.NSFWChannelRequired` that is derived from :exc:`.CheckFailure`. + If used on hybrid commands, this will be equivalent to setting the + application command's ``nsfw`` attribute to ``True``. In an unsupported + context, such as a subcommand, this will still fallback to applying the + check. + .. versionchanged:: 1.1 Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. DM channels will also now pass this check. """ - def pred(ctx: Context[BotT]) -> bool: + # Due to implementation quirks, this check has to be re-implemented completely + # to work with both app_commands and the command framework. + + def predicate(ctx: Context[BotT]) -> bool: ch = ctx.channel if ctx.guild is None or ( isinstance(ch, (discord.TextChannel, discord.Thread, discord.VoiceChannel)) and ch.is_nsfw() @@ -2360,7 +2370,33 @@ def is_nsfw() -> Check[Any]: return True raise NSFWChannelRequired(ch) # type: ignore - return check(pred) + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: + if isinstance(func, Command): + func.checks.append(predicate) + if hasattr(func, '__commands_is_hybrid__'): + app_command = getattr(func, 'app_command', None) + if app_command: + app_command.nsfw = True + else: + if not hasattr(func, '__commands_checks__'): + func.__commands_checks__ = [] + + func.__commands_checks__.append(predicate) + func.__discord_app_commands_is_nsfw__ = True + + return func + + if inspect.iscoroutinefunction(predicate): + decorator.predicate = predicate + else: + + @functools.wraps(predicate) + async def wrapper(ctx: Context[BotT]): + return predicate(ctx) + + decorator.predicate = wrapper + + return decorator # type: ignore def cooldown( diff --git a/discord/ext/commands/hybrid.py b/discord/ext/commands/hybrid.py index 2848fe004..609991910 100644 --- a/discord/ext/commands/hybrid.py +++ b/discord/ext/commands/hybrid.py @@ -290,12 +290,13 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]): signature = inspect.signature(wrapped.callback) params = replace_parameters(wrapped.params, wrapped.callback, signature) wrapped.callback.__signature__ = signature.replace(parameters=params) - + nsfw = getattr(wrapped.callback, '__discord_app_commands_is_nsfw__', False) try: super().__init__( name=wrapped.name, callback=wrapped.callback, # type: ignore # Signature doesn't match but we're overriding the invoke description=wrapped.description or wrapped.short_doc or '…', + nsfw=nsfw, ) finally: del wrapped.callback.__signature__ @@ -595,12 +596,14 @@ class HybridGroup(Group[CogT, P, T]): ) guild_only = getattr(self.callback, '__discord_app_commands_guild_only__', False) default_permissions = getattr(self.callback, '__discord_app_commands_default_permissions__', None) + nsfw = getattr(self.callback, '__discord_app_commands_is_nsfw__', False) self.app_command = app_commands.Group( name=self.name, description=self.description or self.short_doc or '…', guild_ids=guild_ids, guild_only=guild_only, default_permissions=default_permissions, + nsfw=nsfw, ) # This prevents the group from re-adding the command at __init__