From 15ceca1e63f9cae90f8df7346f83b412f1d1d703 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 11 May 2022 03:15:49 -0400 Subject: [PATCH] [commands] Add support for FlagConverter in hybrid commands This works by unpacking and repacking the flag arguments in a flag. If an unsupported type annotation is found then it will error at definition time. --- discord/ext/commands/flags.py | 26 +++++- discord/ext/commands/hybrid.py | 143 ++++++++++++++++++++++++--------- docs/ext/commands/commands.rst | 69 ++++++++++++++++ 3 files changed, 198 insertions(+), 40 deletions(-) diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index b554df18e..8c63bcfde 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -28,7 +28,7 @@ import inspect import re import sys from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Pattern, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Pattern, Set, Tuple, Type, Union from discord.utils import MISSING, maybe_coroutine, resolve_annotation @@ -44,7 +44,7 @@ __all__ = ( if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeGuard from ._types import BotT from .context import Context @@ -76,6 +76,9 @@ class Flag: A negative value indicates an unlimited amount of arguments. override: :class:`bool` Whether multiple given values overrides the previous value. + description: :class:`str` + The description of the flag. Shown for hybrid commands when they're + used as application commands. """ name: str = MISSING @@ -85,6 +88,7 @@ class Flag: default: Any = MISSING max_args: int = MISSING override: bool = MISSING + description: str = MISSING cast_to_dict: bool = False @property @@ -104,6 +108,7 @@ def flag( max_args: int = MISSING, override: bool = MISSING, converter: Any = MISSING, + description: str = MISSING, ) -> Any: """Override default functionality and parameters of the underlying :class:`FlagConverter` class attributes. @@ -128,8 +133,23 @@ def flag( converter: Any The converter to use for this flag. This replaces the annotation at runtime which is transparent to type checkers. + description: :class:`str` + The description of the flag. Shown for hybrid commands when they're + used as application commands. """ - return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override, annotation=converter) + return Flag( + name=name, + aliases=aliases, + default=default, + max_args=max_args, + override=override, + annotation=converter, + description=description, + ) + + +def is_flag(obj: Any) -> TypeGuard[Type[FlagConverter]]: + return hasattr(obj, '__commands_is_flag__') def validate_flag_name(name: str, forbidden: Set[str]) -> None: diff --git a/discord/ext/commands/hybrid.py b/discord/ext/commands/hybrid.py index 8ed9e0816..1946ddfc2 100644 --- a/discord/ext/commands/hybrid.py +++ b/discord/ext/commands/hybrid.py @@ -31,6 +31,7 @@ from typing import ( ClassVar, Dict, List, + Tuple, Type, TypeVar, Union, @@ -45,6 +46,7 @@ from .core import Command, Group from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError from .converter import Converter, Range, Greedy, run_converters from .parameters import Parameter +from .flags import is_flag, FlagConverter from .cog import Cog from .view import StringView @@ -163,7 +165,86 @@ def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_co return type('GreedyTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)}) -def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Signature) -> List[inspect.Parameter]: +def replace_parameter( + param: inspect.Parameter, + converter: Any, + callback: Callable[..., Any], + original: Parameter, + mapping: Dict[str, inspect.Parameter], +) -> inspect.Parameter: + try: + # If it's a supported annotation (i.e. a transformer) just let it pass as-is. + app_commands.transformers.get_supported_annotation(converter) + except TypeError: + # Fallback to see if the behaviour needs changing + origin = getattr(converter, '__origin__', None) + args = getattr(converter, '__args__', []) + if isinstance(converter, Range): + r = converter + param = param.replace(annotation=app_commands.Range[r.annotation, r.min, r.max]) # type: ignore + elif isinstance(converter, Greedy): + # Greedy is "optional" in ext.commands + # However, in here, it probably makes sense to make it required. + # I'm unsure how to allow the user to choose right now. + inner = converter.converter + if inner is discord.Attachment: + raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands') + + param = param.replace(annotation=make_greedy_transformer(inner, original)) + elif is_flag(converter): + callback.__hybrid_command_flag__ = (param.name, converter) + descriptions = {} + renames = {} + for flag in converter.__commands_flags__.values(): + name = flag.attribute + flag_param = inspect.Parameter( + name=name, + kind=param.kind, + default=flag.default if flag.default is not MISSING else inspect.Parameter.empty, + annotation=flag.annotation, + ) + pseudo = replace_parameter(flag_param, flag.annotation, callback, original, mapping) + if name in mapping: + raise TypeError(f'{name!r} flag would shadow a pre-existing parameter') + if flag.description is not MISSING: + descriptions[name] = flag.description + if flag.name != flag.attribute: + renames[name] = flag.name + + mapping[name] = pseudo + + # Manually call the decorators + if descriptions: + app_commands.describe(**descriptions)(callback) + if renames: + app_commands.rename(**renames)(callback) + + elif is_converter(converter): + param = param.replace(annotation=make_converter_transformer(converter)) + elif origin is Union: + if len(args) == 2 and args[-1] is _NoneType: + # Special case Optional[X] where X is a single type that can optionally be a converter + inner = args[0] + is_inner_tranformer = is_transformer(inner) + if is_converter(inner) and not is_inner_tranformer: + param = param.replace(annotation=Optional[make_converter_transformer(inner)]) # type: ignore + else: + raise + elif origin: + # Unsupported typing.X annotation e.g. typing.Dict, typing.Tuple, typing.List, etc. + raise + elif callable(converter) and not inspect.isclass(converter): + param_count = required_pos_arguments(converter) + if param_count != 1: + raise + param = param.replace(annotation=make_callable_transformer(converter)) + + return param + + +def replace_parameters( + parameters: Dict[str, Parameter], callback: Callable[..., Any], signature: inspect.Signature +) -> List[inspect.Parameter]: # Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly params = signature.parameters.copy() for name, parameter in parameters.items(): @@ -171,41 +252,7 @@ def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Sign # Parameter.converter properly infers from the default and has a str default # This allows the actual signature to inherit this property param = params[name].replace(annotation=converter) - try: - # If it's a supported annotation (i.e. a transformer) just let it pass as-is. - app_commands.transformers.get_supported_annotation(converter) - except TypeError: - # Fallback to see if the behaviour needs changing - origin = getattr(converter, '__origin__', None) - args = getattr(converter, '__args__', []) - if isinstance(converter, Range): - r = converter - param = param.replace(annotation=app_commands.Range[r.annotation, r.min, r.max]) # type: ignore - elif isinstance(converter, Greedy): - # Greedy is "optional" in ext.commands - # However, in here, it probably makes sense to make it required. - # I'm unsure how to allow the user to choose right now. - inner = converter.converter - if inner is discord.Attachment: - raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands') - - param = param.replace(annotation=make_greedy_transformer(inner, parameter)) - elif is_converter(converter): - param = param.replace(annotation=make_converter_transformer(converter)) - elif origin is Union: - if len(args) == 2 and args[-1] is _NoneType: - # Special case Optional[X] where X is a single type that can optionally be a converter - inner = args[0] - is_inner_tranformer = is_transformer(inner) - if is_converter(inner) and not is_inner_tranformer: - param = param.replace(annotation=Optional[make_converter_transformer(inner)]) # type: ignore - else: - raise - elif callable(converter) and not inspect.isclass(converter): - param_count = required_pos_arguments(converter) - if param_count != 1: - raise - param = param.replace(annotation=make_callable_transformer(converter)) + param = replace_parameter(param, converter, callback, parameter, params) if parameter.default is not parameter.empty: default = _CallableDefault(parameter.default) if callable(parameter.default) else parameter.default @@ -215,6 +262,11 @@ def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Sign # If we're here, then then it hasn't been handled yet so it should be removed completely param = param.replace(default=parameter.empty) + # Flags are flattened out and thus don't get their parameter in the actual mapping + if hasattr(converter, '__commands_is_flag__'): + del params[name] + continue + params[name] = param return list(params.values()) @@ -223,7 +275,7 @@ def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Sign class HybridAppCommand(discord.app_commands.Command[CogT, P, T]): def __init__(self, wrapped: Command[CogT, Any, T]) -> None: signature = inspect.signature(wrapped.callback) - params = replace_parameters(wrapped.params, signature) + params = replace_parameters(wrapped.params, wrapped.callback, signature) wrapped.callback.__signature__ = signature.replace(parameters=params) try: @@ -237,6 +289,10 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]): self.wrapped: Command[CogT, Any, T] = wrapped self.binding: Optional[CogT] = wrapped.cog + # This technically means only one flag converter is supported + self.flag_converter: Optional[Tuple[str, Type[FlagConverter]]] = getattr( + wrapped.callback, '__hybrid_command_flag__', None + ) def _copy_with(self, **kwargs) -> Self: copy: Self = super()._copy_with(**kwargs) # type: ignore @@ -269,6 +325,19 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]): else: transformed_values[param.name] = await param.transform(interaction, value) + if self.flag_converter is not None: + param_name, flag_cls = self.flag_converter + flag = flag_cls.__new__(flag_cls) + for f in flag_cls.__commands_flags__.values(): + try: + value = transformed_values.pop(f.attribute) + except KeyError: + raise app_commands.CommandSignatureMismatch(self) from None + else: + setattr(flag, f.attribute, value) + + transformed_values[param_name] = flag + return transformed_values async def _check_can_run(self, interaction: discord.Interaction) -> bool: diff --git a/docs/ext/commands/commands.rst b/docs/ext/commands/commands.rst index b6d24d890..d3dc47c28 100644 --- a/docs/ext/commands/commands.rst +++ b/docs/ext/commands/commands.rst @@ -876,6 +876,75 @@ A :class:`dict` annotation is functionally equivalent to ``List[Tuple[K, V]]`` e given as a :class:`dict` rather than a :class:`list`. +Hybrid Command Interaction +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When used as a hybrid command, the parameters are flattened into different parameters for the application command. For example, the following converter: + +.. code-block:: python3 + + class BanFlags(commands.FlagConverter): + member: discord.Member + reason: str + days: int = 1 + + + @commands.hybrid_command() + async def ban(ctx, *, flags: BanFlags): + ... + +Would be equivalent to an application command defined as this: + +.. code-block:: python3 + + @commands.hybrid_command() + async def ban(ctx, member: discord.Member, reason: str, days: int = 1): + ... + +This means that decorators that refer to a parameter by name will use the flag name instead: + +.. code-block:: python3 + + class BanFlags(commands.FlagConverter): + member: discord.Member + reason: str + days: int = 1 + + + @commands.hybrid_command() + @app_commands.describe( + member='The member to ban', + reason='The reason for the ban', + days='The number of days worth of messages to delete', + ) + async def ban(ctx, *, flags: BanFlags): + ... + +For ease of use, the :func:`~ext.commands.flag` function accepts a ``descriptor`` keyword argument to allow you to pass descriptions inline: + +.. code-block:: python3 + + class BanFlags(commands.FlagConverter): + member: discord.Member = commands.flag(description='The member to ban') + reason: str = commands.flag(description='The reason for the ban') + days: int = 1 = commands.flag(description='The number of days worth of messages to delete') + + + @commands.hybrid_command() + async def ban(ctx, *, flags: BanFlags): + ... + + +Note that in hybrid command form, a few annotations are unsupported due to Discord limitations: + +- :data:`typing.Tuple` +- :data:`typing.List` +- :data:`typing.Dict` + +.. note:: + + Only one flag converter is supported per hybrid command. Due to the flag converter's way of working, it is unlikely for a user to have two of them in one signature. + .. _ext_commands_parameter: Parameter Metadata