diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index 7026d4c9b..4b6e2bd3f 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -77,6 +77,7 @@ __all__ = ( 'Group', 'command', 'describe', + 'choices', ) if TYPE_CHECKING: @@ -171,6 +172,31 @@ def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Di raise TypeError(f'unknown parameter given: {first}') +def _populate_choices(params: Dict[str, CommandParameter], all_choices: Dict[str, List[Choice]]) -> None: + for name, param in params.items(): + choices = all_choices.pop(name, MISSING) + if choices is MISSING: + continue + + if not isinstance(choices, list): + raise TypeError('choices must be a list of Choice') + + if not all(isinstance(choice, Choice) for choice in choices): + raise TypeError('choices must be a list of Choice') + + if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer): + raise TypeError('choices are only supported for integer, string, or number option types') + + # There's a type safety hole if someone does Choice[float] as an annotation + # but the values are actually Choice[int]. Since the input-output is the same this feels + # safe enough to ignore. + param.choices = choices + + if all_choices: + first = next(iter(all_choices)) + raise TypeError(f'unknown parameter given: {first}') + + def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]: params = inspect.signature(func).parameters cache = {} @@ -203,6 +229,13 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s else: _populate_descriptions(result, descriptions) + try: + choices = func.__discord_app_commands_param_choices__ + except AttributeError: + pass + else: + _populate_choices(result, choices) + return result @@ -313,15 +346,15 @@ class Command(Generic[GroupT, P, T]): async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T: values = namespace.__dict__ for name, param in self._params.items(): - if not param.required: - values.setdefault(name, param.default) - else: - try: - value = values[name] - except KeyError: - raise CommandSignatureMismatch(self) from None + try: + value = values[name] + except KeyError: + if not param.required: + values[name] = param.default else: - values[name] = await param.transform(interaction, value) + raise CommandSignatureMismatch(self) from None + else: + values[name] = await param.transform(interaction, value) # These type ignores are because the type checker doesn't quite understand the narrowing here # Likewise, it thinks we're missing positional arguments when there aren't any. @@ -768,7 +801,7 @@ def describe(**parameters: str) -> Callable[[T], T]: .. code-block:: python3 @app_commands.command() - @app_commads.describe(member='the member to ban') + @app_commands.describe(member='the member to ban') async def ban(interaction: discord.Interaction, member: discord.Member): await interaction.response.send_message(f'Banned {member}') @@ -787,7 +820,79 @@ def describe(**parameters: str) -> Callable[[T], T]: if isinstance(inner, Command): _populate_descriptions(inner._params, parameters) else: - inner.__discord_app_commands_param_description__ = parameters # type: ignore - Runtime attribute assignment + try: + inner.__discord_app_commands_param_description__.update(parameters) # type: ignore - Runtime attribute access + except AttributeError: + inner.__discord_app_commands_param_description__ = parameters # type: ignore - Runtime attribute assignment + + return inner + + return decorator + + +def choices(**parameters: List[Choice]) -> Callable[[T], T]: + r"""Instructs the given parameters by their name to use the given choices for their choices. + + Example: + + .. code-block:: python3 + + @app_commands.command() + @app_commands.describe(fruits='fruits to choose from') + @app_commands.choices(fruits=[ + Choice(name='apple', value=1), + Choice(name='banana', value=2), + Choice(name='cherry', value=3), + ]) + async def fruit(interaction: discord.Interaction, fruits: Choice[int]): + await interaction.response.send_message(f'Your favourite fruit is {fruits.name}.') + + .. note:: + + This is not the only way to provide choices to a command. There are two more ergonomic ways + of doing this. The first one is to use a :obj:`typing.Literal` annotation: + + .. code-block:: python3 + + @app_commands.command() + @app_commands.describe(fruits='fruits to choose from') + async def fruit(interaction: discord.Interaction, fruits: Literal['apple', 'banana', 'cherry']): + await interaction.response.send_message(f'Your favourite fruit is {fruits}.') + + The second way is to use an :class:`enum.Enum`: + + .. code-block:: python3 + + class Fruits(enum.Enum): + apple = 1 + banana = 2 + cherry = 3 + + @app_commands.command() + @app_commands.describe(fruits='fruits to choose from') + async def fruit(interaction: discord.Interaction, fruits: Fruits): + await interaction.response.send_message(f'Your favourite fruit is {fruits}.') + + + Parameters + ----------- + \*\*parameters + The choices of the parameters. + + Raises + -------- + TypeError + The parameter name is not found. + """ + + def decorator(inner: T) -> T: + if isinstance(inner, Command): + _populate_choices(inner._params, parameters) + else: + try: + inner.__discord_app_commands_param_choices__.update(parameters) # type: ignore - Runtime attribute access + except AttributeError: + inner.__discord_app_commands_param_choices__ = parameters # type: ignore - Runtime attribute assignment return inner diff --git a/discord/app_commands/models.py b/discord/app_commands/models.py index 8ad96f9b0..1e6e34082 100644 --- a/discord/app_commands/models.py +++ b/discord/app_commands/models.py @@ -31,7 +31,7 @@ from ..enums import ChannelType, try_enum from ..mixins import Hashable from ..utils import _get_as_snowflake, parse_time, snowflake_time from .enums import AppCommandOptionType, AppCommandType -from typing import List, NamedTuple, TYPE_CHECKING, Optional, Union +from typing import Generic, List, NamedTuple, TYPE_CHECKING, Optional, TypeVar, Union __all__ = ( 'AppCommand', @@ -42,6 +42,8 @@ __all__ = ( 'Choice', ) +ChoiceT = TypeVar('ChoiceT', str, int, float, Union[str, int, float]) + def is_app_command_argument_type(value: int) -> bool: return 11 >= value >= 3 @@ -145,7 +147,7 @@ class AppCommand(Hashable): return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} type={self.type!r}>' -class Choice(NamedTuple): +class Choice(Generic[ChoiceT]): """Represents an application command argument choice. .. versionadded:: 2.0 @@ -160,6 +162,10 @@ class Choice(NamedTuple): Checks if two choices are not equal. + .. describe:: hash(x) + + Returns the choice's hash. + Parameters ----------- name: :class:`str` @@ -168,8 +174,20 @@ class Choice(NamedTuple): The value of the choice. """ - name: str - value: Union[int, str, float] + __slots__ = ('name', 'value') + + def __init__(self, *, name: str, value: ChoiceT): + self.name: str = name + self.value: ChoiceT = value + + def __eq__(self, o: object) -> bool: + return isinstance(o, Choice) and self.name == o.name and self.value == o.value + + def __hash__(self) -> int: + return hash((self.name, self.value)) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(name={self.name!r}, value={self.value!r})' def to_dict(self) -> ApplicationCommandOptionChoice: return { diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py index 73f5e4b18..12dac76c9 100644 --- a/discord/app_commands/transformers.py +++ b/discord/app_commands/transformers.py @@ -26,7 +26,8 @@ from __future__ import annotations import inspect from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple, Type, TypeVar, Union from .enums import AppCommandOptionType from .errors import TransformerError @@ -113,6 +114,13 @@ class CommandParameter: async def transform(self, interaction: Interaction, value: Any) -> Any: if hasattr(self._annotation, '__discord_app_commands_transformer__'): + # This one needs special handling for type safety reasons + if self._annotation.__discord_app_commands_is_choice__: + choice = next((c for c in self.choices if c.value == value), None) + if choice is None: + raise TransformerError(value, self.type, self._annotation) + return choice + return await self._annotation.transform(interaction, value) return value @@ -149,6 +157,7 @@ class Transformer: """ __discord_app_commands_transformer__: ClassVar[bool] = True + __discord_app_commands_is_choice__: ClassVar[bool] = False @classmethod def type(cls) -> AppCommandOptionType: @@ -221,24 +230,93 @@ class _TransformMetadata: self.metadata: Type[Transformer] = metadata +async def _identity_transform(cls, interaction: Interaction, value: Any) -> Any: + return value + + def _make_range_transformer( opt_type: AppCommandOptionType, *, min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None, ) -> Type[Transformer]: - async def transform(cls, interaction: Interaction, value: Any) -> Any: - return value - ns = { 'type': classmethod(lambda _: opt_type), 'min_value': classmethod(lambda _: min), 'max_value': classmethod(lambda _: max), - 'transform': classmethod(transform), + 'transform': classmethod(_identity_transform), } return type('RangeTransformer', (Transformer,), ns) +def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]: + if len(values) < 2: + raise TypeError(f'typing.Literal requires at least two values.') + + first = type(values[0]) + if first is int: + opt_type = AppCommandOptionType.integer + elif first is float: + opt_type = AppCommandOptionType.number + elif first is str: + opt_type = AppCommandOptionType.string + else: + raise TypeError(f'expected int, str, or float values not {first!r}') + + ns = { + 'type': classmethod(lambda _: opt_type), + 'transform': classmethod(_identity_transform), + '__discord_app_commands_transformer_choices__': [Choice(name=str(v), value=v) for v in values], + } + return type('LiteralTransformer', (Transformer,), ns) + + +def _make_choice_transformer(inner_type: Any) -> Type[Transformer]: + if inner_type is int: + opt_type = AppCommandOptionType.integer + elif inner_type is float: + opt_type = AppCommandOptionType.number + elif inner_type is str: + opt_type = AppCommandOptionType.string + else: + raise TypeError(f'expected int, str, or float values not {inner_type!r}') + + ns = { + 'type': classmethod(lambda _: opt_type), + 'transform': classmethod(_identity_transform), + '__discord_app_commands_is_choice__': True, + } + return type('ChoiceTransformer', (Transformer,), ns) + + +def _make_enum_transformer(enum) -> Type[Transformer]: + values = list(enum) + if len(values) < 2: + raise TypeError(f'enum.Enum requires at least two values.') + + first = type(values[0].value) + if first is int: + opt_type = AppCommandOptionType.integer + elif first is float: + opt_type = AppCommandOptionType.number + elif first is str: + opt_type = AppCommandOptionType.string + else: + raise TypeError(f'expected int, str, or float values not {first!r}') + + async def transform(cls, interaction: Interaction, value: Any) -> Any: + return enum(value) + + ns = { + 'type': classmethod(lambda _: opt_type), + 'transform': classmethod(transform), + '__discord_app_commands_transformer_enum__': enum, + '__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.value) for v in values], + } + + return type(f'{enum.__name__}EnumTransformer', (Transformer,), ns) + + if TYPE_CHECKING: from typing_extensions import Annotated as Transform from typing_extensions import Annotated as Range @@ -465,11 +543,24 @@ def get_supported_annotation( if hasattr(annotation, '__discord_app_commands_transform__'): return (annotation.metadata, MISSING) - if inspect.isclass(annotation) and issubclass(annotation, Transformer): - return (annotation, MISSING) + if inspect.isclass(annotation): + if issubclass(annotation, Transformer): + return (annotation, MISSING) + if issubclass(annotation, Enum): + return (_make_enum_transformer(annotation), MISSING) + if annotation is Choice: + raise TypeError(f'Choice requires a type argument of int, str, or float') # Check if there's an origin origin = getattr(annotation, '__origin__', None) + if origin is Literal: + args = annotation.__args__ # type: ignore + return (_make_literal_transformer(args), MISSING) + + if origin is Choice: + arg = annotation.__args__[0] # type: ignore + return (_make_choice_transformer(arg), MISSING) + if origin is not Union: # Only Union/Optional is supported right now so bail early raise TypeError(f'unsupported type annotation {annotation!r}') @@ -522,9 +613,11 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co # Verify validity of the default parameter if default is not MISSING: - valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) - if not isinstance(default, valid_types): - raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') + enum_type = getattr(inner, '__discord_app_commands_transformer_enum__', None) + if default.__class__ is not enum_type: + valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) + if not isinstance(default, valid_types): + raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') result = CommandParameter( type=type, @@ -534,6 +627,13 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co name=parameter.name, ) + try: + choices = inner.__discord_app_commands_transformer_choices__ + except AttributeError: + pass + else: + result.choices = choices + # These methods should be duck typed if type in (AppCommandOptionType.number, AppCommandOptionType.integer): result.min_value = inner.min_value()