From f072edfdfc5d2f21fa817182c1daae1c39bc0c9f Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 11 Apr 2022 18:55:57 -0400 Subject: [PATCH] [commands] Properly support commands.param in hybrid commands --- discord/ext/commands/hybrid.py | 61 ++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/discord/ext/commands/hybrid.py b/discord/ext/commands/hybrid.py index bd17b59da..89d1ab4d0 100644 --- a/discord/ext/commands/hybrid.py +++ b/discord/ext/commands/hybrid.py @@ -42,8 +42,9 @@ import inspect from discord import app_commands from discord.utils import MISSING, maybe_coroutine, async_all from .core import Command, Group -from .errors import CommandRegistrationError, CommandError, HybridCommandError, ConversionError +from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError from .converter import Converter +from .parameters import Parameter from .cog import Cog if TYPE_CHECKING: @@ -51,7 +52,6 @@ if TYPE_CHECKING: from ._types import ContextT, Coro, BotT from .bot import Bot from .context import Context - from .parameters import Parameter from discord.app_commands.commands import ( Check as AppCommandCheck, AutocompleteCallback, @@ -71,6 +71,7 @@ CogT = TypeVar('CogT', bound='Cog') CommandT = TypeVar('CommandT', bound='Command') # CHT = TypeVar('CHT', bound='Check') GroupT = TypeVar('GroupT', bound='Group') +_NoneType = type(None) if TYPE_CHECKING: P = ParamSpec('P') @@ -85,6 +86,17 @@ else: P2 = TypeVar('P2') +class _CallableDefault: + __slots__ = ('func',) + + def __init__(self, func: Callable[[Context], Any]) -> None: + self.func: Callable[[Context], Any] = func + + @property + def __class__(self) -> Any: + return _NoneType + + def is_converter(converter: Any) -> bool: return (inspect.isclass(converter) and issubclass(converter, Converter)) or isinstance(converter, Converter) @@ -107,12 +119,33 @@ def make_converter_transformer(converter: Any) -> Type[app_commands.Transformer] return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)}) +def make_callable_transformer(func: Callable[[str], Any]) -> Type[app_commands.Transformer]: + async def transform(cls, interaction: discord.Interaction, value: str) -> Any: + try: + return func(value) + except CommandError: + raise + except Exception as exc: + raise BadArgument(f'Converting to "{func.__name__}" failed') from exc + + return type('CallableTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)}) + + def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Signature) -> List[inspect.Parameter]: # Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly params = signature.parameters.copy() for name, parameter in parameters.items(): - if is_converter(parameter.converter) and not hasattr(parameter.converter, '__discord_app_commands_transformer__'): + is_transformer = hasattr(parameter.converter, '__discord_app_commands_transformer__') + if is_converter(parameter.converter) and not is_transformer: params[name] = params[name].replace(annotation=make_converter_transformer(parameter.converter)) + if callable(parameter.converter) and not inspect.isclass(parameter.converter) and not is_transformer: + params[name] = params[name].replace(annotation=make_callable_transformer(parameter.converter)) + if callable(parameter.default): + params[name] = params[name].replace(default=_CallableDefault(parameter.default)) + + if isinstance(params[name].default, Parameter): + # If we're here, then then it hasn't been handled yet so it should be removed completely + params[name] = params[name].replace(default=parameter.empty) return list(params.values()) @@ -146,6 +179,28 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]): } return self._copy_with(parent=self.parent, binding=self.binding, bindings=bindings) + async def _transform_arguments( + self, interaction: discord.Interaction, namespace: app_commands.Namespace + ) -> Dict[str, Any]: + values = namespace.__dict__ + transformed_values = {} + + for param in self._params.values(): + try: + value = values[param.display_name] + except KeyError: + if not param.required: + if isinstance(param.default, _CallableDefault): + transformed_values[param.name] = await maybe_coroutine(param.default.func, interaction._baton) + else: + transformed_values[param.name] = param.default + else: + raise app_commands.CommandSignatureMismatch(self) from None + else: + transformed_values[param.name] = await param.transform(interaction, value) + + return transformed_values + async def _check_can_run(self, interaction: discord.Interaction) -> bool: # Hybrid checks must run like so: # - Bot global check once