Browse Source

Refactor transformers to use instances instead of classmethods

This should allow them to be easier to use for users without requiring
a lot of metaprogramming hackery if you want to involve state.
pull/8280/head
Rapptz 3 years ago
parent
commit
11618cd1ba
  1. 20
      discord/app_commands/errors.py
  2. 503
      discord/app_commands/transformers.py
  3. 3
      discord/ext/commands/core.py
  4. 63
      discord/ext/commands/hybrid.py

20
discord/app_commands/errors.py

@ -121,28 +121,16 @@ class TransformerError(AppCommandError):
The value that failed to convert. The value that failed to convert.
type: :class:`~discord.AppCommandOptionType` type: :class:`~discord.AppCommandOptionType`
The type of argument that failed to convert. The type of argument that failed to convert.
transformer: Type[:class:`Transformer`] transformer: :class:`Transformer`
The transformer that failed the conversion. The transformer that failed the conversion.
""" """
def __init__(self, value: Any, opt_type: AppCommandOptionType, transformer: Type[Transformer]): def __init__(self, value: Any, opt_type: AppCommandOptionType, transformer: Transformer):
self.value: Any = value self.value: Any = value
self.type: AppCommandOptionType = opt_type self.type: AppCommandOptionType = opt_type
self.transformer: Type[Transformer] = transformer self.transformer: Transformer = transformer
try:
result_type = transformer.transform.__annotations__['return']
except KeyError:
name = transformer.__name__
if name.endswith('Transformer'):
result_type = name[:-11]
else:
result_type = name
else:
if isinstance(result_type, type):
result_type = result_type.__name__
super().__init__(f'Failed to convert {value} to {result_type!s}') super().__init__(f'Failed to convert {value} to {transformer._error_display_name!s}')
class CheckFailure(AppCommandError): class CheckFailure(AppCommandError):

503
discord/app_commands/transformers.py

@ -166,7 +166,7 @@ class Transformer:
to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one
from this type. from this type.
This class is customisable through the overriding of :func:`classmethod` in the class This class is customisable through the overriding of methods and properties in the class
and by using it as the second type parameter of the :class:`~discord.app_commands.Transform` and by using it as the second type parameter of the :class:`~discord.app_commands.Transform`
class. For example, to convert a string into a custom pair type: class. For example, to convert a string into a custom pair type:
@ -177,8 +177,7 @@ class Transformer:
y: int y: int
class PointTransformer(app_commands.Transformer): class PointTransformer(app_commands.Transformer):
@classmethod async def transform(self, interaction: discord.Interaction, value: str) -> Point:
async def transform(cls, interaction: discord.Interaction, value: str) -> Point:
(x, _, y) = value.partition(',') (x, _, y) = value.partition(',')
return Point(x=int(x.strip()), y=int(y.strip())) return Point(x=int(x.strip()), y=int(y.strip()))
@ -189,56 +188,90 @@ class Transformer:
): ):
await interaction.response.send_message(str(point)) await interaction.response.send_message(str(point))
If a class is passed instead of an instance to the second type parameter, then it is
constructed with no arguments passed to the ``__init__`` method.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
__discord_app_commands_transformer__: ClassVar[bool] = True __discord_app_commands_transformer__: ClassVar[bool] = True
__discord_app_commands_is_choice__: ClassVar[bool] = False __discord_app_commands_is_choice__: ClassVar[bool] = False
@classmethod # This is needed to pass typing's type checks.
def type(cls) -> AppCommandOptionType: # e.g. Optional[MyTransformer]
def __call__(self) -> None:
pass
@property
def type(self) -> AppCommandOptionType:
""":class:`~discord.AppCommandOptionType`: The option type associated with this transformer. """:class:`~discord.AppCommandOptionType`: The option type associated with this transformer.
This must be a :obj:`classmethod`. This must be a :obj:`property`.
Defaults to :attr:`~discord.AppCommandOptionType.string`. Defaults to :attr:`~discord.AppCommandOptionType.string`.
""" """
return AppCommandOptionType.string return AppCommandOptionType.string
@classmethod @property
def channel_types(cls) -> List[ChannelType]: def channel_types(self) -> List[ChannelType]:
"""List[:class:`~discord.ChannelType`]: A list of channel types that are allowed to this parameter. """List[:class:`~discord.ChannelType`]: A list of channel types that are allowed to this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.channel`. Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.channel`.
This must be a :obj:`property`.
Defaults to an empty list. Defaults to an empty list.
""" """
return [] return []
@classmethod @property
def min_value(cls) -> Optional[Union[int, float]]: def min_value(self) -> Optional[Union[int, float]]:
"""Optional[:class:`int`]: The minimum supported value for this parameter. """Optional[:class:`int`]: The minimum supported value for this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number` Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`. :attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``. Defaults to ``None``.
""" """
return None return None
@classmethod @property
def max_value(cls) -> Optional[Union[int, float]]: def max_value(self) -> Optional[Union[int, float]]:
"""Optional[:class:`int`]: The maximum supported value for this parameter. """Optional[:class:`int`]: The maximum supported value for this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number` Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`. :attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``. Defaults to ``None``.
""" """
return None return None
@classmethod @property
async def transform(cls, interaction: Interaction, value: Any) -> Any: def choices(self) -> Optional[List[Choice[Union[int, float, str]]]]:
"""Optional[List[:class:`~discord.app_commands.Choice`]]: A list of choices that are allowed to this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``.
"""
return None
@property
def _error_display_name(self) -> str:
name = self.__class__.__name__
if name.endswith('Transformer'):
return name[:-11]
else:
return name
async def transform(self, interaction: Interaction, value: Any) -> Any:
"""|maybecoro| """|maybecoro|
Transforms the converted option value into another value. Transforms the converted option value into another value.
@ -257,9 +290,8 @@ class Transformer:
""" """
raise NotImplementedError('Derived classes need to implement this.') raise NotImplementedError('Derived classes need to implement this.')
@classmethod
async def autocomplete( async def autocomplete(
cls, interaction: Interaction, value: Union[int, float, str] self, interaction: Interaction, value: Union[int, float, str]
) -> List[Choice[Union[int, float, str]]]: ) -> List[Choice[Union[int, float, str]]]:
"""|coro| """|coro|
@ -287,122 +319,141 @@ class Transformer:
raise NotImplementedError('Derived classes can implement this.') raise NotImplementedError('Derived classes can implement this.')
class _TransformMetadata: class IdentityTransformer(Transformer):
__discord_app_commands_transform__: ClassVar[bool] = True def __init__(self, type: AppCommandOptionType) -> None:
__slots__ = ('metadata',) self._type = type
def __init__(self, metadata: Type[Transformer]): @property
self.metadata: Type[Transformer] = metadata def type(self) -> AppCommandOptionType:
return self._type
# This is needed to pass typing's type checks. async def transform(self, interaction: Interaction, value: Any) -> Any:
# e.g. Optional[Transform[discord.Member, MyTransformer]] return value
def __call__(self) -> None:
pass
async def _identity_transform(cls, interaction: Interaction, value: Any) -> Any: class RangeTransformer(IdentityTransformer):
return value def __init__(
self,
opt_type: AppCommandOptionType,
*,
min: Optional[Union[int, float]] = None,
max: Optional[Union[int, float]] = None,
) -> None:
if min and max and min > max:
raise TypeError('minimum cannot be larger than maximum')
self._min: Optional[Union[int, float]] = min
self._max: Optional[Union[int, float]] = max
super().__init__(opt_type)
@property
def min_value(self) -> Optional[Union[int, float]]:
return self._min
@property
def max_value(self) -> Optional[Union[int, float]]:
return self._max
class LiteralTransformer(IdentityTransformer):
def __init__(self, values: Tuple[Any, ...]) -> None:
first = type(values[0])
if first is int:
opt_type = AppCommandOptionType.integer
elif first is float:
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
self._choices = [Choice(name=str(v), value=v) for v in values]
super().__init__(opt_type)
@property
def choices(self):
return self._choices
class ChoiceTransformer(IdentityTransformer):
__discord_app_commands_is_choice__: ClassVar[bool] = True
def __init__(self, inner_type: Any) -> None:
if inner_type is int:
opt_type = AppCommandOptionType.integer
elif inner_type is float:
opt_type = AppCommandOptionType.number
elif inner_type is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {inner_type!r}')
super().__init__(opt_type)
def _make_range_transformer(
opt_type: AppCommandOptionType, class EnumValueTransformer(Transformer):
*, def __init__(self, enum: Any) -> None:
min: Optional[Union[int, float]] = None, super().__init__()
max: Optional[Union[int, float]] = None,
) -> Type[Transformer]: values = list(enum)
if min and max and min > max: if len(values) < 2:
raise TypeError('minimum cannot be larger than maximum') raise TypeError(f'enum.Enum requires at least two values.')
ns = { first = type(values[0].value)
'type': classmethod(lambda _: opt_type), if first is int:
'min_value': classmethod(lambda _: min), opt_type = AppCommandOptionType.integer
'max_value': classmethod(lambda _: max), elif first is float:
'transform': classmethod(_identity_transform), opt_type = AppCommandOptionType.number
} elif first is str:
return type('RangeTransformer', (Transformer,), ns) opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]:
first = type(values[0]) self._type: AppCommandOptionType = opt_type
if first is int: self._enum: Any = enum
opt_type = AppCommandOptionType.integer self._choices = [Choice(name=v.name, value=v.value) for v in values]
elif first is float:
opt_type = AppCommandOptionType.number @property
elif first is str: def _error_display_name(self) -> str:
opt_type = AppCommandOptionType.string return self._enum.__name__
else:
raise TypeError(f'expected int, str, or float values not {first!r}') @property
def type(self) -> AppCommandOptionType:
ns = { return self._type
'type': classmethod(lambda _: opt_type),
'transform': classmethod(_identity_transform), @property
'__discord_app_commands_transformer_choices__': [Choice(name=str(v), value=v) for v in values], def choices(self):
} return self._choices
return type('LiteralTransformer', (Transformer,), ns)
async def transform(self, interaction: Interaction, value: Any) -> Any:
return self._enum(value)
def _make_choice_transformer(inner_type: Any) -> Type[Transformer]:
if inner_type is int:
opt_type = AppCommandOptionType.integer class EnumNameTransformer(Transformer):
elif inner_type is float: def __init__(self, enum: Any) -> None:
opt_type = AppCommandOptionType.number super().__init__()
elif inner_type is str:
opt_type = AppCommandOptionType.string values = list(enum)
else: if len(values) < 2:
raise TypeError(f'expected int, str, or float values not {inner_type!r}') raise TypeError(f'enum.Enum requires at least two values.')
ns = { self._enum: Any = enum
'type': classmethod(lambda _: opt_type), self._choices = [Choice(name=v.name, value=v.value) for v in values]
'transform': classmethod(_identity_transform),
'__discord_app_commands_is_choice__': True, @property
} def _error_display_name(self) -> str:
return type('ChoiceTransformer', (Transformer,), ns) return self._enum.__name__
@property
def _make_enum_transformer(enum) -> Type[Transformer]: def type(self) -> AppCommandOptionType:
values = list(enum) return AppCommandOptionType.string
if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.') @property
def choices(self):
first = type(values[0].value) return self._choices
if first is int:
opt_type = AppCommandOptionType.integer async def transform(self, interaction: Interaction, value: Any) -> Any:
elif first is float: return self._enum[value]
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return enum(value)
ns = {
'type': classmethod(lambda _: opt_type),
'transform': classmethod(transform),
'__discord_app_commands_transformer_enum__': enum,
'__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.value) for v in values],
}
return type(f'{enum.__name__}EnumTransformer', (Transformer,), ns)
def _make_complex_enum_transformer(enum) -> Type[Transformer]:
values = list(enum)
if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.')
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return enum[value]
ns = {
'type': classmethod(lambda _: AppCommandOptionType.string),
'transform': classmethod(transform),
'__discord_app_commands_transformer_enum__': enum,
'__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.name) for v in values],
}
return type(f'{enum.__name__}ComplexEnumTransformer', (Transformer,), ns)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -433,11 +484,14 @@ else:
_, transformer = items _, transformer = items
is_valid = inspect.isclass(transformer) and issubclass(transformer, Transformer) if inspect.isclass(transformer):
if not is_valid: if not issubclass(transformer, Transformer):
raise TypeError(f'second argument of Transform must be a Transformer class not {transformer!r}') raise TypeError(f'second argument of Transform must be a Transformer class not {transformer!r}')
transformer = transformer()
elif not isinstance(transformer, Transformer):
raise TypeError(f'second argument of Transform must be a Transformer not {transformer.__class__!r}')
return _TransformMetadata(transformer) return transformer
class Range: class Range:
"""A type annotation that can be applied to a parameter to require a numeric or string """A type annotation that can be applied to a parameter to require a numeric or string
@ -497,88 +551,80 @@ else:
else: else:
cast = float cast = float
transformer = _make_range_transformer( transformer = RangeTransformer(
opt_type, opt_type,
min=cast(min) if min is not None else None, min=cast(min) if min is not None else None,
max=cast(max) if max is not None else None, max=cast(max) if max is not None else None,
) )
return _TransformMetadata(transformer) return transformer
def passthrough_transformer(opt_type: AppCommandOptionType) -> Type[Transformer]:
class _Generated(Transformer):
@classmethod
def type(cls) -> AppCommandOptionType:
return opt_type
@classmethod
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return value
return _Generated
class MemberTransformer(Transformer): class MemberTransformer(Transformer):
@classmethod @property
def type(cls) -> AppCommandOptionType: def type(self) -> AppCommandOptionType:
return AppCommandOptionType.user return AppCommandOptionType.user
@classmethod async def transform(self, interaction: Interaction, value: Any) -> Member:
async def transform(cls, interaction: Interaction, value: Any) -> Member:
if not isinstance(value, Member): if not isinstance(value, Member):
raise TransformerError(value, cls.type(), cls) raise TransformerError(value, self.type, self)
return value return value
def channel_transformer(*channel_types: Type[Any], raw: Optional[bool] = False) -> Type[Transformer]: class BaseChannelTransformer(Transformer):
if raw: def __init__(self, *channel_types: Type[Any]) -> None:
super().__init__()
if len(channel_types) == 1:
display_name = channel_types[0].__name__
types = CHANNEL_TO_TYPES[channel_types[0]]
else:
display_name = '{}, and {}'.format(', '.join(t.__name__ for t in channel_types[:-1]), channel_types[-1].__name__)
types = []
async def transform(cls, interaction: Interaction, value: Any): for t in channel_types:
if not isinstance(value, channel_types): try:
raise TransformerError(value, AppCommandOptionType.channel, cls) types.extend(CHANNEL_TO_TYPES[t])
return value except KeyError:
raise TypeError(f'Union type of channels must be entirely made up of channels') from None
elif raw is False: self._types: Tuple[Type[Any]] = channel_types
self._channel_types: List[ChannelType] = types
self._display_name = display_name
async def transform(cls, interaction: Interaction, value: Any): @property
resolved = value.resolve() def _error_display_name(self) -> str:
if resolved is None or not isinstance(resolved, channel_types): return self._display_name
raise TransformerError(value, AppCommandOptionType.channel, cls)
return resolved
else: @property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.channel
async def transform(cls, interaction: Interaction, value: Any): @property
if isinstance(value, channel_types): def channel_types(self) -> List[ChannelType]:
return value return self._channel_types
resolved = value.resolve() async def transform(self, interaction: Interaction, value: Any):
if resolved is None or not isinstance(resolved, channel_types): resolved = value.resolve()
raise TransformerError(value, AppCommandOptionType.channel, cls) if resolved is None or not isinstance(resolved, self._types):
return resolved raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
if len(channel_types) == 1:
name = channel_types[0].__name__
types = CHANNEL_TO_TYPES[channel_types[0]]
else:
name = 'MultiChannel'
types = []
for t in channel_types: class RawChannelTransformer(BaseChannelTransformer):
try: async def transform(self, interaction: Interaction, value: Any):
types.extend(CHANNEL_TO_TYPES[t]) if not isinstance(value, self._types):
except KeyError: raise TransformerError(value, AppCommandOptionType.channel, self)
raise TypeError(f'Union type of channels must be entirely made up of channels') from None return value
return type(
f'{name}Transformer', class UnionChannelTransformer(BaseChannelTransformer):
(Transformer,), async def transform(self, interaction: Interaction, value: Any):
{ if isinstance(value, self._types):
'type': classmethod(lambda cls: AppCommandOptionType.channel), return value
'transform': classmethod(transform),
'channel_types': classmethod(lambda cls: types), resolved = value.resolve()
}, if resolved is None or not isinstance(resolved, self._types):
) raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = { CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
@ -604,23 +650,23 @@ CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
CategoryChannel: [ChannelType.category], CategoryChannel: [ChannelType.category],
} }
BUILT_IN_TRANSFORMERS: Dict[Any, Type[Transformer]] = { BUILT_IN_TRANSFORMERS: Dict[Any, Transformer] = {
str: passthrough_transformer(AppCommandOptionType.string), str: IdentityTransformer(AppCommandOptionType.string),
int: passthrough_transformer(AppCommandOptionType.integer), int: IdentityTransformer(AppCommandOptionType.integer),
float: passthrough_transformer(AppCommandOptionType.number), float: IdentityTransformer(AppCommandOptionType.number),
bool: passthrough_transformer(AppCommandOptionType.boolean), bool: IdentityTransformer(AppCommandOptionType.boolean),
User: passthrough_transformer(AppCommandOptionType.user), User: IdentityTransformer(AppCommandOptionType.user),
Member: MemberTransformer, Member: MemberTransformer(),
Role: passthrough_transformer(AppCommandOptionType.role), Role: IdentityTransformer(AppCommandOptionType.role),
AppCommandChannel: channel_transformer(AppCommandChannel, raw=True), AppCommandChannel: RawChannelTransformer(AppCommandChannel),
AppCommandThread: channel_transformer(AppCommandThread, raw=True), AppCommandThread: RawChannelTransformer(AppCommandThread),
GuildChannel: channel_transformer(GuildChannel), GuildChannel: BaseChannelTransformer(GuildChannel),
Thread: channel_transformer(Thread), Thread: BaseChannelTransformer(Thread),
StageChannel: channel_transformer(StageChannel), StageChannel: BaseChannelTransformer(StageChannel),
VoiceChannel: channel_transformer(VoiceChannel), VoiceChannel: BaseChannelTransformer(VoiceChannel),
TextChannel: channel_transformer(TextChannel), TextChannel: BaseChannelTransformer(TextChannel),
CategoryChannel: channel_transformer(CategoryChannel), CategoryChannel: BaseChannelTransformer(CategoryChannel),
Attachment: passthrough_transformer(AppCommandOptionType.attachment), Attachment: IdentityTransformer(AppCommandOptionType.attachment),
} }
ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = { ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
@ -635,7 +681,7 @@ def get_supported_annotation(
annotation: Any, annotation: Any,
*, *,
_none: type = NoneType, _none: type = NoneType,
_mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS, _mapping: Dict[Any, Transformer] = BUILT_IN_TRANSFORMERS,
) -> Tuple[Any, Any, bool]: ) -> Tuple[Any, Any, bool]:
"""Returns an appropriate, yet supported, annotation along with an optional default value. """Returns an appropriate, yet supported, annotation along with an optional default value.
@ -650,20 +696,20 @@ def get_supported_annotation(
except KeyError: except KeyError:
pass pass
if hasattr(annotation, '__discord_app_commands_transform__'): if isinstance(annotation, Transformer):
return (annotation.metadata, MISSING, False) return (annotation, MISSING, False)
if hasattr(annotation, '__metadata__'): if hasattr(annotation, '__metadata__'):
return get_supported_annotation(annotation.__metadata__[0]) return get_supported_annotation(annotation.__metadata__[0])
if inspect.isclass(annotation): if inspect.isclass(annotation):
if issubclass(annotation, Transformer): if issubclass(annotation, Transformer):
return (annotation, MISSING, False) return (annotation(), MISSING, False)
if issubclass(annotation, (Enum, InternalEnum)): if issubclass(annotation, (Enum, InternalEnum)):
if all(isinstance(v.value, (str, int, float)) for v in annotation): if all(isinstance(v.value, (str, int, float)) for v in annotation):
return (_make_enum_transformer(annotation), MISSING, False) return (EnumValueTransformer(annotation), MISSING, False)
else: else:
return (_make_complex_enum_transformer(annotation), MISSING, False) return (EnumNameTransformer(annotation), MISSING, False)
if annotation is Choice: if annotation is Choice:
raise TypeError(f'Choice requires a type argument of int, str, or float') raise TypeError(f'Choice requires a type argument of int, str, or float')
@ -671,11 +717,11 @@ def get_supported_annotation(
origin = getattr(annotation, '__origin__', None) origin = getattr(annotation, '__origin__', None)
if origin is Literal: if origin is Literal:
args = annotation.__args__ # type: ignore args = annotation.__args__ # type: ignore
return (_make_literal_transformer(args), MISSING, True) return (LiteralTransformer(args), MISSING, True)
if origin is Choice: if origin is Choice:
arg = annotation.__args__[0] # type: ignore arg = annotation.__args__[0] # type: ignore
return (_make_choice_transformer(arg), MISSING, True) return (ChoiceTransformer(arg), MISSING, True)
if origin is not Union: if origin is not Union:
# Only Union/Optional is supported right now so bail early # Only Union/Optional is supported right now so bail early
@ -697,7 +743,7 @@ def get_supported_annotation(
# Check for channel union types # Check for channel union types
if any(arg in CHANNEL_TO_TYPES for arg in args): if any(arg in CHANNEL_TO_TYPES for arg in args):
# If any channel type is given, then *all* must be channel types # If any channel type is given, then *all* must be channel types
return (channel_transformer(*args, raw=None), default, True) return (UnionChannelTransformer(*args), default, True)
# The only valid transformations here are: # The only valid transformations here are:
# [Member, User] => user # [Member, User] => user
@ -707,9 +753,9 @@ def get_supported_annotation(
if not all(arg in supported_types for arg in args): if not all(arg in supported_types for arg in args):
raise TypeError(f'unsupported types given inside {annotation!r}') raise TypeError(f'unsupported types given inside {annotation!r}')
if args == (User, Member) or args == (Member, User): if args == (User, Member) or args == (Member, User):
return (passthrough_transformer(AppCommandOptionType.user), default, True) return (IdentityTransformer(AppCommandOptionType.user), default, True)
return (passthrough_transformer(AppCommandOptionType.mentionable), default, True) return (IdentityTransformer(AppCommandOptionType.mentionable), default, True)
def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter: def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter:
@ -721,7 +767,7 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
""" """
(inner, default, validate_default) = get_supported_annotation(annotation) (inner, default, validate_default) = get_supported_annotation(annotation)
type = inner.type() type = inner.type
if default is MISSING or default is None: if default is MISSING or default is None:
param_default = parameter.default param_default = parameter.default
@ -742,26 +788,23 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
name=parameter.name, name=parameter.name,
) )
try: choices = inner.choices
choices = inner.__discord_app_commands_transformer_choices__ if choices is not None:
except AttributeError:
pass
else:
result.choices = choices result.choices = choices
# These methods should be duck typed # These methods should be duck typed
if type in (AppCommandOptionType.number, AppCommandOptionType.string, AppCommandOptionType.integer): if type in (AppCommandOptionType.number, AppCommandOptionType.string, AppCommandOptionType.integer):
result.min_value = inner.min_value() result.min_value = inner.min_value
result.max_value = inner.max_value() result.max_value = inner.max_value
if type is AppCommandOptionType.channel: if type is AppCommandOptionType.channel:
result.channel_types = inner.channel_types() result.channel_types = inner.channel_types
if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL): if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL):
raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}') raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}')
autocomplete_func = getattr(inner.autocomplete, '__func__', inner.autocomplete) # Check if the method is overridden
if autocomplete_func is not Transformer.autocomplete.__func__: if inner.autocomplete.__func__ is not Transformer.autocomplete:
from .commands import _validate_auto_complete_callback from .commands import _validate_auto_complete_callback
result.autocomplete = _validate_auto_complete_callback(inner.autocomplete, skip_binding=True) result.autocomplete = _validate_auto_complete_callback(inner.autocomplete, skip_binding=True)

3
discord/ext/commands/core.py

@ -165,9 +165,6 @@ def get_signature_parameters(
if len(metadata) >= 1: if len(metadata) >= 1:
annotation = metadata[0] annotation = metadata[0]
if isinstance(annotation, discord.app_commands.transformers._TransformMetadata):
annotation = annotation.metadata
params[name] = parameter.replace(annotation=annotation) params[name] = parameter.replace(annotation=annotation)
return params return params

63
discord/ext/commands/hybrid.py

@ -116,18 +116,24 @@ def required_pos_arguments(func: Callable[..., Any]) -> int:
return sum(p.default is p.empty for p in sig.parameters.values()) return sum(p.default is p.empty for p in sig.parameters.values())
def make_converter_transformer(converter: Any, parameter: Parameter) -> Type[app_commands.Transformer]: class ConverterTransformer(app_commands.Transformer):
try: def __init__(self, converter: Any, parameter: Parameter) -> None:
module = converter.__module__ super().__init__()
except AttributeError: self.converter: Any = converter
pass self.parameter: Parameter = parameter
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')): try:
converter = CONVERTER_MAPPING.get(converter, converter) module = converter.__module__
except AttributeError:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any: pass
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
self.converter = CONVERTER_MAPPING.get(converter, converter)
async def transform(self, interaction: discord.Interaction, value: str) -> Any:
ctx = interaction._baton ctx = interaction._baton
ctx.current_parameter = parameter converter = self.converter
ctx.current_parameter = self.parameter
ctx.current_argument = value ctx.current_argument = value
try: try:
if inspect.isclass(converter) and issubclass(converter, Converter): if inspect.isclass(converter) and issubclass(converter, Converter):
@ -142,27 +148,34 @@ def make_converter_transformer(converter: Any, parameter: Parameter) -> Type[app
except Exception as exc: except Exception as exc:
raise ConversionError(converter, exc) from exc # type: ignore raise ConversionError(converter, exc) from exc # type: ignore
return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
class CallableTransformer(app_commands.Transformer):
def __init__(self, func: Callable[[str], Any]) -> None:
super().__init__()
self.func: Callable[[str], Any] = func
def make_callable_transformer(func: Callable[[str], Any]) -> Type[app_commands.Transformer]: async def transform(self, interaction: discord.Interaction, value: str) -> Any:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
try: try:
return func(value) return self.func(value)
except CommandError: except CommandError:
raise raise
except Exception as exc: except Exception as exc:
raise BadArgument(f'Converting to "{func.__name__}" failed') from exc raise BadArgument(f'Converting to "{self.func.__name__}" failed') from exc
return type('CallableTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
class GreedyTransformer(app_commands.Transformer):
def __init__(self, converter: Any, parameter: Parameter) -> None:
super().__init__()
self.converter: Any = converter
self.parameter: Parameter = parameter
def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_commands.Transformer]: async def transform(self, interaction: discord.Interaction, value: str) -> Any:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
view = StringView(value) view = StringView(value)
result = [] result = []
ctx = interaction._baton ctx = interaction._baton
ctx.current_parameter = parameter ctx.current_parameter = parameter = self.parameter
converter = self.converter
while True: while True:
view.skip_ws() view.skip_ws()
ctx.current_argument = arg = view.get_quoted_word() ctx.current_argument = arg = view.get_quoted_word()
@ -175,8 +188,6 @@ def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_co
return result return result
return type('GreedyTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def replace_parameter( def replace_parameter(
param: inspect.Parameter, param: inspect.Parameter,
@ -203,7 +214,7 @@ def replace_parameter(
if inner is discord.Attachment: if inner is discord.Attachment:
raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands') raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands')
param = param.replace(annotation=make_greedy_transformer(inner, original)) param = param.replace(annotation=GreedyTransformer(inner, original))
elif is_flag(converter): elif is_flag(converter):
callback.__hybrid_command_flag__ = (param.name, converter) callback.__hybrid_command_flag__ = (param.name, converter)
descriptions = {} descriptions = {}
@ -233,14 +244,14 @@ def replace_parameter(
app_commands.rename(**renames)(callback) app_commands.rename(**renames)(callback)
elif is_converter(converter) or converter in CONVERTER_MAPPING: elif is_converter(converter) or converter in CONVERTER_MAPPING:
param = param.replace(annotation=make_converter_transformer(converter, original)) param = param.replace(annotation=ConverterTransformer(converter, original))
elif origin is Union: elif origin is Union:
if len(args) == 2 and args[-1] is _NoneType: if len(args) == 2 and args[-1] is _NoneType:
# Special case Optional[X] where X is a single type that can optionally be a converter # Special case Optional[X] where X is a single type that can optionally be a converter
inner = args[0] inner = args[0]
is_inner_transformer = is_transformer(inner) is_inner_transformer = is_transformer(inner)
if is_converter(inner) and not is_inner_transformer: if is_converter(inner) and not is_inner_transformer:
param = param.replace(annotation=Optional[make_converter_transformer(inner, original)]) # type: ignore param = param.replace(annotation=Optional[ConverterTransformer(inner, original)]) # type: ignore
else: else:
raise raise
elif origin: elif origin:
@ -250,7 +261,7 @@ def replace_parameter(
param_count = required_pos_arguments(converter) param_count = required_pos_arguments(converter)
if param_count != 1: if param_count != 1:
raise raise
param = param.replace(annotation=make_callable_transformer(converter)) param = param.replace(annotation=CallableTransformer(converter))
return param return param

Loading…
Cancel
Save