Browse Source

Add support for choice option parameters

This implements it in three different ways:

* The first is using typing.Literal for quick and easy ones
* The second is using enum.Enum for slightly more complex ones
* The last is using a Choice type hint with a decorator to pass
  a list of choices.

This should hopefully cover most use cases.
pull/7492/head
Rapptz 3 years ago
parent
commit
4e04dbdec7
  1. 125
      discord/app_commands/commands.py
  2. 26
      discord/app_commands/models.py
  3. 120
      discord/app_commands/transformers.py

125
discord/app_commands/commands.py

@ -77,6 +77,7 @@ __all__ = (
'Group', 'Group',
'command', 'command',
'describe', 'describe',
'choices',
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -171,6 +172,31 @@ def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Di
raise TypeError(f'unknown parameter given: {first}') raise TypeError(f'unknown parameter given: {first}')
def _populate_choices(params: Dict[str, CommandParameter], all_choices: Dict[str, List[Choice]]) -> None:
for name, param in params.items():
choices = all_choices.pop(name, MISSING)
if choices is MISSING:
continue
if not isinstance(choices, list):
raise TypeError('choices must be a list of Choice')
if not all(isinstance(choice, Choice) for choice in choices):
raise TypeError('choices must be a list of Choice')
if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer):
raise TypeError('choices are only supported for integer, string, or number option types')
# There's a type safety hole if someone does Choice[float] as an annotation
# but the values are actually Choice[int]. Since the input-output is the same this feels
# safe enough to ignore.
param.choices = choices
if all_choices:
first = next(iter(all_choices))
raise TypeError(f'unknown parameter given: {first}')
def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]: def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]:
params = inspect.signature(func).parameters params = inspect.signature(func).parameters
cache = {} cache = {}
@ -203,6 +229,13 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s
else: else:
_populate_descriptions(result, descriptions) _populate_descriptions(result, descriptions)
try:
choices = func.__discord_app_commands_param_choices__
except AttributeError:
pass
else:
_populate_choices(result, choices)
return result return result
@ -313,15 +346,15 @@ class Command(Generic[GroupT, P, T]):
async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T: async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T:
values = namespace.__dict__ values = namespace.__dict__
for name, param in self._params.items(): for name, param in self._params.items():
if not param.required: try:
values.setdefault(name, param.default) value = values[name]
else: except KeyError:
try: if not param.required:
value = values[name] values[name] = param.default
except KeyError:
raise CommandSignatureMismatch(self) from None
else: else:
values[name] = await param.transform(interaction, value) raise CommandSignatureMismatch(self) from None
else:
values[name] = await param.transform(interaction, value)
# These type ignores are because the type checker doesn't quite understand the narrowing here # These type ignores are because the type checker doesn't quite understand the narrowing here
# Likewise, it thinks we're missing positional arguments when there aren't any. # Likewise, it thinks we're missing positional arguments when there aren't any.
@ -768,7 +801,7 @@ def describe(**parameters: str) -> Callable[[T], T]:
.. code-block:: python3 .. code-block:: python3
@app_commands.command() @app_commands.command()
@app_commads.describe(member='the member to ban') @app_commands.describe(member='the member to ban')
async def ban(interaction: discord.Interaction, member: discord.Member): async def ban(interaction: discord.Interaction, member: discord.Member):
await interaction.response.send_message(f'Banned {member}') await interaction.response.send_message(f'Banned {member}')
@ -787,7 +820,79 @@ def describe(**parameters: str) -> Callable[[T], T]:
if isinstance(inner, Command): if isinstance(inner, Command):
_populate_descriptions(inner._params, parameters) _populate_descriptions(inner._params, parameters)
else: else:
inner.__discord_app_commands_param_description__ = parameters # type: ignore - Runtime attribute assignment try:
inner.__discord_app_commands_param_description__.update(parameters) # type: ignore - Runtime attribute access
except AttributeError:
inner.__discord_app_commands_param_description__ = parameters # type: ignore - Runtime attribute assignment
return inner
return decorator
def choices(**parameters: List[Choice]) -> Callable[[T], T]:
r"""Instructs the given parameters by their name to use the given choices for their choices.
Example:
.. code-block:: python3
@app_commands.command()
@app_commands.describe(fruits='fruits to choose from')
@app_commands.choices(fruits=[
Choice(name='apple', value=1),
Choice(name='banana', value=2),
Choice(name='cherry', value=3),
])
async def fruit(interaction: discord.Interaction, fruits: Choice[int]):
await interaction.response.send_message(f'Your favourite fruit is {fruits.name}.')
.. note::
This is not the only way to provide choices to a command. There are two more ergonomic ways
of doing this. The first one is to use a :obj:`typing.Literal` annotation:
.. code-block:: python3
@app_commands.command()
@app_commands.describe(fruits='fruits to choose from')
async def fruit(interaction: discord.Interaction, fruits: Literal['apple', 'banana', 'cherry']):
await interaction.response.send_message(f'Your favourite fruit is {fruits}.')
The second way is to use an :class:`enum.Enum`:
.. code-block:: python3
class Fruits(enum.Enum):
apple = 1
banana = 2
cherry = 3
@app_commands.command()
@app_commands.describe(fruits='fruits to choose from')
async def fruit(interaction: discord.Interaction, fruits: Fruits):
await interaction.response.send_message(f'Your favourite fruit is {fruits}.')
Parameters
-----------
\*\*parameters
The choices of the parameters.
Raises
--------
TypeError
The parameter name is not found.
"""
def decorator(inner: T) -> T:
if isinstance(inner, Command):
_populate_choices(inner._params, parameters)
else:
try:
inner.__discord_app_commands_param_choices__.update(parameters) # type: ignore - Runtime attribute access
except AttributeError:
inner.__discord_app_commands_param_choices__ = parameters # type: ignore - Runtime attribute assignment
return inner return inner

26
discord/app_commands/models.py

@ -31,7 +31,7 @@ from ..enums import ChannelType, try_enum
from ..mixins import Hashable from ..mixins import Hashable
from ..utils import _get_as_snowflake, parse_time, snowflake_time from ..utils import _get_as_snowflake, parse_time, snowflake_time
from .enums import AppCommandOptionType, AppCommandType from .enums import AppCommandOptionType, AppCommandType
from typing import List, NamedTuple, TYPE_CHECKING, Optional, Union from typing import Generic, List, NamedTuple, TYPE_CHECKING, Optional, TypeVar, Union
__all__ = ( __all__ = (
'AppCommand', 'AppCommand',
@ -42,6 +42,8 @@ __all__ = (
'Choice', 'Choice',
) )
ChoiceT = TypeVar('ChoiceT', str, int, float, Union[str, int, float])
def is_app_command_argument_type(value: int) -> bool: def is_app_command_argument_type(value: int) -> bool:
return 11 >= value >= 3 return 11 >= value >= 3
@ -145,7 +147,7 @@ class AppCommand(Hashable):
return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} type={self.type!r}>' return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} type={self.type!r}>'
class Choice(NamedTuple): class Choice(Generic[ChoiceT]):
"""Represents an application command argument choice. """Represents an application command argument choice.
.. versionadded:: 2.0 .. versionadded:: 2.0
@ -160,6 +162,10 @@ class Choice(NamedTuple):
Checks if two choices are not equal. Checks if two choices are not equal.
.. describe:: hash(x)
Returns the choice's hash.
Parameters Parameters
----------- -----------
name: :class:`str` name: :class:`str`
@ -168,8 +174,20 @@ class Choice(NamedTuple):
The value of the choice. The value of the choice.
""" """
name: str __slots__ = ('name', 'value')
value: Union[int, str, float]
def __init__(self, *, name: str, value: ChoiceT):
self.name: str = name
self.value: ChoiceT = value
def __eq__(self, o: object) -> bool:
return isinstance(o, Choice) and self.name == o.name and self.value == o.value
def __hash__(self) -> int:
return hash((self.name, self.value))
def __repr__(self) -> str:
return f'{self.__class__.__name__}(name={self.name!r}, value={self.value!r})'
def to_dict(self) -> ApplicationCommandOptionChoice: def to_dict(self) -> ApplicationCommandOptionChoice:
return { return {

120
discord/app_commands/transformers.py

@ -26,7 +26,8 @@ from __future__ import annotations
import inspect import inspect
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple, Type, TypeVar, Union
from .enums import AppCommandOptionType from .enums import AppCommandOptionType
from .errors import TransformerError from .errors import TransformerError
@ -113,6 +114,13 @@ class CommandParameter:
async def transform(self, interaction: Interaction, value: Any) -> Any: async def transform(self, interaction: Interaction, value: Any) -> Any:
if hasattr(self._annotation, '__discord_app_commands_transformer__'): if hasattr(self._annotation, '__discord_app_commands_transformer__'):
# This one needs special handling for type safety reasons
if self._annotation.__discord_app_commands_is_choice__:
choice = next((c for c in self.choices if c.value == value), None)
if choice is None:
raise TransformerError(value, self.type, self._annotation)
return choice
return await self._annotation.transform(interaction, value) return await self._annotation.transform(interaction, value)
return value return value
@ -149,6 +157,7 @@ class Transformer:
""" """
__discord_app_commands_transformer__: ClassVar[bool] = True __discord_app_commands_transformer__: ClassVar[bool] = True
__discord_app_commands_is_choice__: ClassVar[bool] = False
@classmethod @classmethod
def type(cls) -> AppCommandOptionType: def type(cls) -> AppCommandOptionType:
@ -221,24 +230,93 @@ class _TransformMetadata:
self.metadata: Type[Transformer] = metadata self.metadata: Type[Transformer] = metadata
async def _identity_transform(cls, interaction: Interaction, value: Any) -> Any:
return value
def _make_range_transformer( def _make_range_transformer(
opt_type: AppCommandOptionType, opt_type: AppCommandOptionType,
*, *,
min: Optional[Union[int, float]] = None, min: Optional[Union[int, float]] = None,
max: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None,
) -> Type[Transformer]: ) -> Type[Transformer]:
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return value
ns = { ns = {
'type': classmethod(lambda _: opt_type), 'type': classmethod(lambda _: opt_type),
'min_value': classmethod(lambda _: min), 'min_value': classmethod(lambda _: min),
'max_value': classmethod(lambda _: max), 'max_value': classmethod(lambda _: max),
'transform': classmethod(transform), 'transform': classmethod(_identity_transform),
} }
return type('RangeTransformer', (Transformer,), ns) return type('RangeTransformer', (Transformer,), ns)
def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]:
if len(values) < 2:
raise TypeError(f'typing.Literal requires at least two values.')
first = type(values[0])
if first is int:
opt_type = AppCommandOptionType.integer
elif first is float:
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
ns = {
'type': classmethod(lambda _: opt_type),
'transform': classmethod(_identity_transform),
'__discord_app_commands_transformer_choices__': [Choice(name=str(v), value=v) for v in values],
}
return type('LiteralTransformer', (Transformer,), ns)
def _make_choice_transformer(inner_type: Any) -> Type[Transformer]:
if inner_type is int:
opt_type = AppCommandOptionType.integer
elif inner_type is float:
opt_type = AppCommandOptionType.number
elif inner_type is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {inner_type!r}')
ns = {
'type': classmethod(lambda _: opt_type),
'transform': classmethod(_identity_transform),
'__discord_app_commands_is_choice__': True,
}
return type('ChoiceTransformer', (Transformer,), ns)
def _make_enum_transformer(enum) -> Type[Transformer]:
values = list(enum)
if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.')
first = type(values[0].value)
if first is int:
opt_type = AppCommandOptionType.integer
elif first is float:
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return enum(value)
ns = {
'type': classmethod(lambda _: opt_type),
'transform': classmethod(transform),
'__discord_app_commands_transformer_enum__': enum,
'__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.value) for v in values],
}
return type(f'{enum.__name__}EnumTransformer', (Transformer,), ns)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Annotated as Transform from typing_extensions import Annotated as Transform
from typing_extensions import Annotated as Range from typing_extensions import Annotated as Range
@ -465,11 +543,24 @@ def get_supported_annotation(
if hasattr(annotation, '__discord_app_commands_transform__'): if hasattr(annotation, '__discord_app_commands_transform__'):
return (annotation.metadata, MISSING) return (annotation.metadata, MISSING)
if inspect.isclass(annotation) and issubclass(annotation, Transformer): if inspect.isclass(annotation):
return (annotation, MISSING) if issubclass(annotation, Transformer):
return (annotation, MISSING)
if issubclass(annotation, Enum):
return (_make_enum_transformer(annotation), MISSING)
if annotation is Choice:
raise TypeError(f'Choice requires a type argument of int, str, or float')
# Check if there's an origin # Check if there's an origin
origin = getattr(annotation, '__origin__', None) origin = getattr(annotation, '__origin__', None)
if origin is Literal:
args = annotation.__args__ # type: ignore
return (_make_literal_transformer(args), MISSING)
if origin is Choice:
arg = annotation.__args__[0] # type: ignore
return (_make_choice_transformer(arg), MISSING)
if origin is not Union: if origin is not Union:
# Only Union/Optional is supported right now so bail early # Only Union/Optional is supported right now so bail early
raise TypeError(f'unsupported type annotation {annotation!r}') raise TypeError(f'unsupported type annotation {annotation!r}')
@ -522,9 +613,11 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
# Verify validity of the default parameter # Verify validity of the default parameter
if default is not MISSING: if default is not MISSING:
valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) enum_type = getattr(inner, '__discord_app_commands_transformer_enum__', None)
if not isinstance(default, valid_types): if default.__class__ is not enum_type:
raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,))
if not isinstance(default, valid_types):
raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}')
result = CommandParameter( result = CommandParameter(
type=type, type=type,
@ -534,6 +627,13 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
name=parameter.name, name=parameter.name,
) )
try:
choices = inner.__discord_app_commands_transformer_choices__
except AttributeError:
pass
else:
result.choices = choices
# These methods should be duck typed # These methods should be duck typed
if type in (AppCommandOptionType.number, AppCommandOptionType.integer): if type in (AppCommandOptionType.number, AppCommandOptionType.integer):
result.min_value = inner.min_value() result.min_value = inner.min_value()

Loading…
Cancel
Save