Browse Source

[commands] Add support for FlagConverter in hybrid commands

This works by unpacking and repacking the flag arguments in a flag.
If an unsupported type annotation is found then it will error at
definition time.
pull/8026/head
Rapptz 3 years ago
parent
commit
15ceca1e63
  1. 26
      discord/ext/commands/flags.py
  2. 143
      discord/ext/commands/hybrid.py
  3. 69
      docs/ext/commands/commands.rst

26
discord/ext/commands/flags.py

@ -28,7 +28,7 @@ import inspect
import re import re
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Pattern, Set, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Pattern, Set, Tuple, Type, Union
from discord.utils import MISSING, maybe_coroutine, resolve_annotation from discord.utils import MISSING, maybe_coroutine, resolve_annotation
@ -44,7 +44,7 @@ __all__ = (
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self, TypeGuard
from ._types import BotT from ._types import BotT
from .context import Context from .context import Context
@ -76,6 +76,9 @@ class Flag:
A negative value indicates an unlimited amount of arguments. A negative value indicates an unlimited amount of arguments.
override: :class:`bool` override: :class:`bool`
Whether multiple given values overrides the previous value. Whether multiple given values overrides the previous value.
description: :class:`str`
The description of the flag. Shown for hybrid commands when they're
used as application commands.
""" """
name: str = MISSING name: str = MISSING
@ -85,6 +88,7 @@ class Flag:
default: Any = MISSING default: Any = MISSING
max_args: int = MISSING max_args: int = MISSING
override: bool = MISSING override: bool = MISSING
description: str = MISSING
cast_to_dict: bool = False cast_to_dict: bool = False
@property @property
@ -104,6 +108,7 @@ def flag(
max_args: int = MISSING, max_args: int = MISSING,
override: bool = MISSING, override: bool = MISSING,
converter: Any = MISSING, converter: Any = MISSING,
description: str = MISSING,
) -> Any: ) -> Any:
"""Override default functionality and parameters of the underlying :class:`FlagConverter` """Override default functionality and parameters of the underlying :class:`FlagConverter`
class attributes. class attributes.
@ -128,8 +133,23 @@ def flag(
converter: Any converter: Any
The converter to use for this flag. This replaces the annotation at The converter to use for this flag. This replaces the annotation at
runtime which is transparent to type checkers. runtime which is transparent to type checkers.
description: :class:`str`
The description of the flag. Shown for hybrid commands when they're
used as application commands.
""" """
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override, annotation=converter) return Flag(
name=name,
aliases=aliases,
default=default,
max_args=max_args,
override=override,
annotation=converter,
description=description,
)
def is_flag(obj: Any) -> TypeGuard[Type[FlagConverter]]:
return hasattr(obj, '__commands_is_flag__')
def validate_flag_name(name: str, forbidden: Set[str]) -> None: def validate_flag_name(name: str, forbidden: Set[str]) -> None:

143
discord/ext/commands/hybrid.py

@ -31,6 +31,7 @@ from typing import (
ClassVar, ClassVar,
Dict, Dict,
List, List,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -45,6 +46,7 @@ from .core import Command, Group
from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError
from .converter import Converter, Range, Greedy, run_converters from .converter import Converter, Range, Greedy, run_converters
from .parameters import Parameter from .parameters import Parameter
from .flags import is_flag, FlagConverter
from .cog import Cog from .cog import Cog
from .view import StringView from .view import StringView
@ -163,7 +165,86 @@ def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_co
return type('GreedyTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)}) return type('GreedyTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Signature) -> List[inspect.Parameter]: def replace_parameter(
param: inspect.Parameter,
converter: Any,
callback: Callable[..., Any],
original: Parameter,
mapping: Dict[str, inspect.Parameter],
) -> inspect.Parameter:
try:
# If it's a supported annotation (i.e. a transformer) just let it pass as-is.
app_commands.transformers.get_supported_annotation(converter)
except TypeError:
# Fallback to see if the behaviour needs changing
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', [])
if isinstance(converter, Range):
r = converter
param = param.replace(annotation=app_commands.Range[r.annotation, r.min, r.max]) # type: ignore
elif isinstance(converter, Greedy):
# Greedy is "optional" in ext.commands
# However, in here, it probably makes sense to make it required.
# I'm unsure how to allow the user to choose right now.
inner = converter.converter
if inner is discord.Attachment:
raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands')
param = param.replace(annotation=make_greedy_transformer(inner, original))
elif is_flag(converter):
callback.__hybrid_command_flag__ = (param.name, converter)
descriptions = {}
renames = {}
for flag in converter.__commands_flags__.values():
name = flag.attribute
flag_param = inspect.Parameter(
name=name,
kind=param.kind,
default=flag.default if flag.default is not MISSING else inspect.Parameter.empty,
annotation=flag.annotation,
)
pseudo = replace_parameter(flag_param, flag.annotation, callback, original, mapping)
if name in mapping:
raise TypeError(f'{name!r} flag would shadow a pre-existing parameter')
if flag.description is not MISSING:
descriptions[name] = flag.description
if flag.name != flag.attribute:
renames[name] = flag.name
mapping[name] = pseudo
# Manually call the decorators
if descriptions:
app_commands.describe(**descriptions)(callback)
if renames:
app_commands.rename(**renames)(callback)
elif is_converter(converter):
param = param.replace(annotation=make_converter_transformer(converter))
elif origin is Union:
if len(args) == 2 and args[-1] is _NoneType:
# Special case Optional[X] where X is a single type that can optionally be a converter
inner = args[0]
is_inner_tranformer = is_transformer(inner)
if is_converter(inner) and not is_inner_tranformer:
param = param.replace(annotation=Optional[make_converter_transformer(inner)]) # type: ignore
else:
raise
elif origin:
# Unsupported typing.X annotation e.g. typing.Dict, typing.Tuple, typing.List, etc.
raise
elif callable(converter) and not inspect.isclass(converter):
param_count = required_pos_arguments(converter)
if param_count != 1:
raise
param = param.replace(annotation=make_callable_transformer(converter))
return param
def replace_parameters(
parameters: Dict[str, Parameter], callback: Callable[..., Any], signature: inspect.Signature
) -> List[inspect.Parameter]:
# Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly # Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly
params = signature.parameters.copy() params = signature.parameters.copy()
for name, parameter in parameters.items(): for name, parameter in parameters.items():
@ -171,41 +252,7 @@ def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Sign
# Parameter.converter properly infers from the default and has a str default # Parameter.converter properly infers from the default and has a str default
# This allows the actual signature to inherit this property # This allows the actual signature to inherit this property
param = params[name].replace(annotation=converter) param = params[name].replace(annotation=converter)
try: param = replace_parameter(param, converter, callback, parameter, params)
# If it's a supported annotation (i.e. a transformer) just let it pass as-is.
app_commands.transformers.get_supported_annotation(converter)
except TypeError:
# Fallback to see if the behaviour needs changing
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', [])
if isinstance(converter, Range):
r = converter
param = param.replace(annotation=app_commands.Range[r.annotation, r.min, r.max]) # type: ignore
elif isinstance(converter, Greedy):
# Greedy is "optional" in ext.commands
# However, in here, it probably makes sense to make it required.
# I'm unsure how to allow the user to choose right now.
inner = converter.converter
if inner is discord.Attachment:
raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands')
param = param.replace(annotation=make_greedy_transformer(inner, parameter))
elif is_converter(converter):
param = param.replace(annotation=make_converter_transformer(converter))
elif origin is Union:
if len(args) == 2 and args[-1] is _NoneType:
# Special case Optional[X] where X is a single type that can optionally be a converter
inner = args[0]
is_inner_tranformer = is_transformer(inner)
if is_converter(inner) and not is_inner_tranformer:
param = param.replace(annotation=Optional[make_converter_transformer(inner)]) # type: ignore
else:
raise
elif callable(converter) and not inspect.isclass(converter):
param_count = required_pos_arguments(converter)
if param_count != 1:
raise
param = param.replace(annotation=make_callable_transformer(converter))
if parameter.default is not parameter.empty: if parameter.default is not parameter.empty:
default = _CallableDefault(parameter.default) if callable(parameter.default) else parameter.default default = _CallableDefault(parameter.default) if callable(parameter.default) else parameter.default
@ -215,6 +262,11 @@ def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Sign
# If we're here, then then it hasn't been handled yet so it should be removed completely # If we're here, then then it hasn't been handled yet so it should be removed completely
param = param.replace(default=parameter.empty) param = param.replace(default=parameter.empty)
# Flags are flattened out and thus don't get their parameter in the actual mapping
if hasattr(converter, '__commands_is_flag__'):
del params[name]
continue
params[name] = param params[name] = param
return list(params.values()) return list(params.values())
@ -223,7 +275,7 @@ def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Sign
class HybridAppCommand(discord.app_commands.Command[CogT, P, T]): class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
def __init__(self, wrapped: Command[CogT, Any, T]) -> None: def __init__(self, wrapped: Command[CogT, Any, T]) -> None:
signature = inspect.signature(wrapped.callback) signature = inspect.signature(wrapped.callback)
params = replace_parameters(wrapped.params, signature) params = replace_parameters(wrapped.params, wrapped.callback, signature)
wrapped.callback.__signature__ = signature.replace(parameters=params) wrapped.callback.__signature__ = signature.replace(parameters=params)
try: try:
@ -237,6 +289,10 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
self.wrapped: Command[CogT, Any, T] = wrapped self.wrapped: Command[CogT, Any, T] = wrapped
self.binding: Optional[CogT] = wrapped.cog self.binding: Optional[CogT] = wrapped.cog
# This technically means only one flag converter is supported
self.flag_converter: Optional[Tuple[str, Type[FlagConverter]]] = getattr(
wrapped.callback, '__hybrid_command_flag__', None
)
def _copy_with(self, **kwargs) -> Self: def _copy_with(self, **kwargs) -> Self:
copy: Self = super()._copy_with(**kwargs) # type: ignore copy: Self = super()._copy_with(**kwargs) # type: ignore
@ -269,6 +325,19 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
else: else:
transformed_values[param.name] = await param.transform(interaction, value) transformed_values[param.name] = await param.transform(interaction, value)
if self.flag_converter is not None:
param_name, flag_cls = self.flag_converter
flag = flag_cls.__new__(flag_cls)
for f in flag_cls.__commands_flags__.values():
try:
value = transformed_values.pop(f.attribute)
except KeyError:
raise app_commands.CommandSignatureMismatch(self) from None
else:
setattr(flag, f.attribute, value)
transformed_values[param_name] = flag
return transformed_values return transformed_values
async def _check_can_run(self, interaction: discord.Interaction) -> bool: async def _check_can_run(self, interaction: discord.Interaction) -> bool:

69
docs/ext/commands/commands.rst

@ -876,6 +876,75 @@ A :class:`dict` annotation is functionally equivalent to ``List[Tuple[K, V]]`` e
given as a :class:`dict` rather than a :class:`list`. given as a :class:`dict` rather than a :class:`list`.
Hybrid Command Interaction
^^^^^^^^^^^^^^^^^^^^^^^^^^^
When used as a hybrid command, the parameters are flattened into different parameters for the application command. For example, the following converter:
.. code-block:: python3
class BanFlags(commands.FlagConverter):
member: discord.Member
reason: str
days: int = 1
@commands.hybrid_command()
async def ban(ctx, *, flags: BanFlags):
...
Would be equivalent to an application command defined as this:
.. code-block:: python3
@commands.hybrid_command()
async def ban(ctx, member: discord.Member, reason: str, days: int = 1):
...
This means that decorators that refer to a parameter by name will use the flag name instead:
.. code-block:: python3
class BanFlags(commands.FlagConverter):
member: discord.Member
reason: str
days: int = 1
@commands.hybrid_command()
@app_commands.describe(
member='The member to ban',
reason='The reason for the ban',
days='The number of days worth of messages to delete',
)
async def ban(ctx, *, flags: BanFlags):
...
For ease of use, the :func:`~ext.commands.flag` function accepts a ``descriptor`` keyword argument to allow you to pass descriptions inline:
.. code-block:: python3
class BanFlags(commands.FlagConverter):
member: discord.Member = commands.flag(description='The member to ban')
reason: str = commands.flag(description='The reason for the ban')
days: int = 1 = commands.flag(description='The number of days worth of messages to delete')
@commands.hybrid_command()
async def ban(ctx, *, flags: BanFlags):
...
Note that in hybrid command form, a few annotations are unsupported due to Discord limitations:
- :data:`typing.Tuple`
- :data:`typing.List`
- :data:`typing.Dict`
.. note::
Only one flag converter is supported per hybrid command. Due to the flag converter's way of working, it is unlikely for a user to have two of them in one signature.
.. _ext_commands_parameter: .. _ext_commands_parameter:
Parameter Metadata Parameter Metadata

Loading…
Cancel
Save