Browse Source

Refactor transformers to use instances instead of classmethods

This should allow them to be easier to use for users without requiring
a lot of metaprogramming hackery if you want to involve state.
pull/8280/head
Rapptz 3 years ago
parent
commit
11618cd1ba
  1. 20
      discord/app_commands/errors.py
  2. 385
      discord/app_commands/transformers.py
  3. 3
      discord/ext/commands/core.py
  4. 49
      discord/ext/commands/hybrid.py

20
discord/app_commands/errors.py

@ -121,28 +121,16 @@ class TransformerError(AppCommandError):
The value that failed to convert. The value that failed to convert.
type: :class:`~discord.AppCommandOptionType` type: :class:`~discord.AppCommandOptionType`
The type of argument that failed to convert. The type of argument that failed to convert.
transformer: Type[:class:`Transformer`] transformer: :class:`Transformer`
The transformer that failed the conversion. The transformer that failed the conversion.
""" """
def __init__(self, value: Any, opt_type: AppCommandOptionType, transformer: Type[Transformer]): def __init__(self, value: Any, opt_type: AppCommandOptionType, transformer: Transformer):
self.value: Any = value self.value: Any = value
self.type: AppCommandOptionType = opt_type self.type: AppCommandOptionType = opt_type
self.transformer: Type[Transformer] = transformer self.transformer: Transformer = transformer
try:
result_type = transformer.transform.__annotations__['return']
except KeyError:
name = transformer.__name__
if name.endswith('Transformer'):
result_type = name[:-11]
else:
result_type = name
else:
if isinstance(result_type, type):
result_type = result_type.__name__
super().__init__(f'Failed to convert {value} to {result_type!s}') super().__init__(f'Failed to convert {value} to {transformer._error_display_name!s}')
class CheckFailure(AppCommandError): class CheckFailure(AppCommandError):

385
discord/app_commands/transformers.py

@ -166,7 +166,7 @@ class Transformer:
to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one
from this type. from this type.
This class is customisable through the overriding of :func:`classmethod` in the class This class is customisable through the overriding of methods and properties in the class
and by using it as the second type parameter of the :class:`~discord.app_commands.Transform` and by using it as the second type parameter of the :class:`~discord.app_commands.Transform`
class. For example, to convert a string into a custom pair type: class. For example, to convert a string into a custom pair type:
@ -177,8 +177,7 @@ class Transformer:
y: int y: int
class PointTransformer(app_commands.Transformer): class PointTransformer(app_commands.Transformer):
@classmethod async def transform(self, interaction: discord.Interaction, value: str) -> Point:
async def transform(cls, interaction: discord.Interaction, value: str) -> Point:
(x, _, y) = value.partition(',') (x, _, y) = value.partition(',')
return Point(x=int(x.strip()), y=int(y.strip())) return Point(x=int(x.strip()), y=int(y.strip()))
@ -189,56 +188,90 @@ class Transformer:
): ):
await interaction.response.send_message(str(point)) await interaction.response.send_message(str(point))
If a class is passed instead of an instance to the second type parameter, then it is
constructed with no arguments passed to the ``__init__`` method.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
__discord_app_commands_transformer__: ClassVar[bool] = True __discord_app_commands_transformer__: ClassVar[bool] = True
__discord_app_commands_is_choice__: ClassVar[bool] = False __discord_app_commands_is_choice__: ClassVar[bool] = False
@classmethod # This is needed to pass typing's type checks.
def type(cls) -> AppCommandOptionType: # e.g. Optional[MyTransformer]
def __call__(self) -> None:
pass
@property
def type(self) -> AppCommandOptionType:
""":class:`~discord.AppCommandOptionType`: The option type associated with this transformer. """:class:`~discord.AppCommandOptionType`: The option type associated with this transformer.
This must be a :obj:`classmethod`. This must be a :obj:`property`.
Defaults to :attr:`~discord.AppCommandOptionType.string`. Defaults to :attr:`~discord.AppCommandOptionType.string`.
""" """
return AppCommandOptionType.string return AppCommandOptionType.string
@classmethod @property
def channel_types(cls) -> List[ChannelType]: def channel_types(self) -> List[ChannelType]:
"""List[:class:`~discord.ChannelType`]: A list of channel types that are allowed to this parameter. """List[:class:`~discord.ChannelType`]: A list of channel types that are allowed to this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.channel`. Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.channel`.
This must be a :obj:`property`.
Defaults to an empty list. Defaults to an empty list.
""" """
return [] return []
@classmethod @property
def min_value(cls) -> Optional[Union[int, float]]: def min_value(self) -> Optional[Union[int, float]]:
"""Optional[:class:`int`]: The minimum supported value for this parameter. """Optional[:class:`int`]: The minimum supported value for this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number` Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`. :attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``. Defaults to ``None``.
""" """
return None return None
@classmethod @property
def max_value(cls) -> Optional[Union[int, float]]: def max_value(self) -> Optional[Union[int, float]]:
"""Optional[:class:`int`]: The maximum supported value for this parameter. """Optional[:class:`int`]: The maximum supported value for this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number` Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`. :attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``.
"""
return None
@property
def choices(self) -> Optional[List[Choice[Union[int, float, str]]]]:
"""Optional[List[:class:`~discord.app_commands.Choice`]]: A list of choices that are allowed to this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``. Defaults to ``None``.
""" """
return None return None
@classmethod @property
async def transform(cls, interaction: Interaction, value: Any) -> Any: def _error_display_name(self) -> str:
name = self.__class__.__name__
if name.endswith('Transformer'):
return name[:-11]
else:
return name
async def transform(self, interaction: Interaction, value: Any) -> Any:
"""|maybecoro| """|maybecoro|
Transforms the converted option value into another value. Transforms the converted option value into another value.
@ -257,9 +290,8 @@ class Transformer:
""" """
raise NotImplementedError('Derived classes need to implement this.') raise NotImplementedError('Derived classes need to implement this.')
@classmethod
async def autocomplete( async def autocomplete(
cls, interaction: Interaction, value: Union[int, float, str] self, interaction: Interaction, value: Union[int, float, str]
) -> List[Choice[Union[int, float, str]]]: ) -> List[Choice[Union[int, float, str]]]:
"""|coro| """|coro|
@ -287,42 +319,44 @@ class Transformer:
raise NotImplementedError('Derived classes can implement this.') raise NotImplementedError('Derived classes can implement this.')
class _TransformMetadata: class IdentityTransformer(Transformer):
__discord_app_commands_transform__: ClassVar[bool] = True def __init__(self, type: AppCommandOptionType) -> None:
__slots__ = ('metadata',) self._type = type
def __init__(self, metadata: Type[Transformer]):
self.metadata: Type[Transformer] = metadata
# This is needed to pass typing's type checks.
# e.g. Optional[Transform[discord.Member, MyTransformer]]
def __call__(self) -> None:
pass
@property
def type(self) -> AppCommandOptionType:
return self._type
async def _identity_transform(cls, interaction: Interaction, value: Any) -> Any: async def transform(self, interaction: Interaction, value: Any) -> Any:
return value return value
def _make_range_transformer( class RangeTransformer(IdentityTransformer):
def __init__(
self,
opt_type: AppCommandOptionType, opt_type: AppCommandOptionType,
*, *,
min: Optional[Union[int, float]] = None, min: Optional[Union[int, float]] = None,
max: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None,
) -> Type[Transformer]: ) -> None:
if min and max and min > max: if min and max and min > max:
raise TypeError('minimum cannot be larger than maximum') raise TypeError('minimum cannot be larger than maximum')
ns = { self._min: Optional[Union[int, float]] = min
'type': classmethod(lambda _: opt_type), self._max: Optional[Union[int, float]] = max
'min_value': classmethod(lambda _: min), super().__init__(opt_type)
'max_value': classmethod(lambda _: max),
'transform': classmethod(_identity_transform), @property
} def min_value(self) -> Optional[Union[int, float]]:
return type('RangeTransformer', (Transformer,), ns) return self._min
@property
def max_value(self) -> Optional[Union[int, float]]:
return self._max
def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]: class LiteralTransformer(IdentityTransformer):
def __init__(self, values: Tuple[Any, ...]) -> None:
first = type(values[0]) first = type(values[0])
if first is int: if first is int:
opt_type = AppCommandOptionType.integer opt_type = AppCommandOptionType.integer
@ -333,15 +367,18 @@ def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]:
else: else:
raise TypeError(f'expected int, str, or float values not {first!r}') raise TypeError(f'expected int, str, or float values not {first!r}')
ns = { self._choices = [Choice(name=str(v), value=v) for v in values]
'type': classmethod(lambda _: opt_type), super().__init__(opt_type)
'transform': classmethod(_identity_transform),
'__discord_app_commands_transformer_choices__': [Choice(name=str(v), value=v) for v in values], @property
} def choices(self):
return type('LiteralTransformer', (Transformer,), ns) return self._choices
def _make_choice_transformer(inner_type: Any) -> Type[Transformer]: class ChoiceTransformer(IdentityTransformer):
__discord_app_commands_is_choice__: ClassVar[bool] = True
def __init__(self, inner_type: Any) -> None:
if inner_type is int: if inner_type is int:
opt_type = AppCommandOptionType.integer opt_type = AppCommandOptionType.integer
elif inner_type is float: elif inner_type is float:
@ -351,15 +388,13 @@ def _make_choice_transformer(inner_type: Any) -> Type[Transformer]:
else: else:
raise TypeError(f'expected int, str, or float values not {inner_type!r}') raise TypeError(f'expected int, str, or float values not {inner_type!r}')
ns = { super().__init__(opt_type)
'type': classmethod(lambda _: opt_type),
'transform': classmethod(_identity_transform),
'__discord_app_commands_is_choice__': True,
}
return type('ChoiceTransformer', (Transformer,), ns)
def _make_enum_transformer(enum) -> Type[Transformer]: class EnumValueTransformer(Transformer):
def __init__(self, enum: Any) -> None:
super().__init__()
values = list(enum) values = list(enum)
if len(values) < 2: if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.') raise TypeError(f'enum.Enum requires at least two values.')
@ -374,35 +409,51 @@ def _make_enum_transformer(enum) -> Type[Transformer]:
else: else:
raise TypeError(f'expected int, str, or float values not {first!r}') raise TypeError(f'expected int, str, or float values not {first!r}')
async def transform(cls, interaction: Interaction, value: Any) -> Any: self._type: AppCommandOptionType = opt_type
return enum(value) self._enum: Any = enum
self._choices = [Choice(name=v.name, value=v.value) for v in values]
ns = { @property
'type': classmethod(lambda _: opt_type), def _error_display_name(self) -> str:
'transform': classmethod(transform), return self._enum.__name__
'__discord_app_commands_transformer_enum__': enum,
'__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.value) for v in values], @property
} def type(self) -> AppCommandOptionType:
return self._type
@property
def choices(self):
return self._choices
async def transform(self, interaction: Interaction, value: Any) -> Any:
return self._enum(value)
return type(f'{enum.__name__}EnumTransformer', (Transformer,), ns)
class EnumNameTransformer(Transformer):
def __init__(self, enum: Any) -> None:
super().__init__()
def _make_complex_enum_transformer(enum) -> Type[Transformer]:
values = list(enum) values = list(enum)
if len(values) < 2: if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.') raise TypeError(f'enum.Enum requires at least two values.')
async def transform(cls, interaction: Interaction, value: Any) -> Any: self._enum: Any = enum
return enum[value] self._choices = [Choice(name=v.name, value=v.value) for v in values]
ns = { @property
'type': classmethod(lambda _: AppCommandOptionType.string), def _error_display_name(self) -> str:
'transform': classmethod(transform), return self._enum.__name__
'__discord_app_commands_transformer_enum__': enum,
'__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.name) for v in values],
}
return type(f'{enum.__name__}ComplexEnumTransformer', (Transformer,), ns) @property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.string
@property
def choices(self):
return self._choices
async def transform(self, interaction: Interaction, value: Any) -> Any:
return self._enum[value]
if TYPE_CHECKING: if TYPE_CHECKING:
@ -433,11 +484,14 @@ else:
_, transformer = items _, transformer = items
is_valid = inspect.isclass(transformer) and issubclass(transformer, Transformer) if inspect.isclass(transformer):
if not is_valid: if not issubclass(transformer, Transformer):
raise TypeError(f'second argument of Transform must be a Transformer class not {transformer!r}') raise TypeError(f'second argument of Transform must be a Transformer class not {transformer!r}')
transformer = transformer()
elif not isinstance(transformer, Transformer):
raise TypeError(f'second argument of Transform must be a Transformer not {transformer.__class__!r}')
return _TransformMetadata(transformer) return transformer
class Range: class Range:
"""A type annotation that can be applied to a parameter to require a numeric or string """A type annotation that can be applied to a parameter to require a numeric or string
@ -497,71 +551,33 @@ else:
else: else:
cast = float cast = float
transformer = _make_range_transformer( transformer = RangeTransformer(
opt_type, opt_type,
min=cast(min) if min is not None else None, min=cast(min) if min is not None else None,
max=cast(max) if max is not None else None, max=cast(max) if max is not None else None,
) )
return _TransformMetadata(transformer) return transformer
def passthrough_transformer(opt_type: AppCommandOptionType) -> Type[Transformer]:
class _Generated(Transformer):
@classmethod
def type(cls) -> AppCommandOptionType:
return opt_type
@classmethod
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return value
return _Generated
class MemberTransformer(Transformer): class MemberTransformer(Transformer):
@classmethod @property
def type(cls) -> AppCommandOptionType: def type(self) -> AppCommandOptionType:
return AppCommandOptionType.user return AppCommandOptionType.user
@classmethod async def transform(self, interaction: Interaction, value: Any) -> Member:
async def transform(cls, interaction: Interaction, value: Any) -> Member:
if not isinstance(value, Member): if not isinstance(value, Member):
raise TransformerError(value, cls.type(), cls) raise TransformerError(value, self.type, self)
return value
def channel_transformer(*channel_types: Type[Any], raw: Optional[bool] = False) -> Type[Transformer]:
if raw:
async def transform(cls, interaction: Interaction, value: Any):
if not isinstance(value, channel_types):
raise TransformerError(value, AppCommandOptionType.channel, cls)
return value
elif raw is False:
async def transform(cls, interaction: Interaction, value: Any):
resolved = value.resolve()
if resolved is None or not isinstance(resolved, channel_types):
raise TransformerError(value, AppCommandOptionType.channel, cls)
return resolved
else:
async def transform(cls, interaction: Interaction, value: Any):
if isinstance(value, channel_types):
return value return value
resolved = value.resolve()
if resolved is None or not isinstance(resolved, channel_types):
raise TransformerError(value, AppCommandOptionType.channel, cls)
return resolved
class BaseChannelTransformer(Transformer):
def __init__(self, *channel_types: Type[Any]) -> None:
super().__init__()
if len(channel_types) == 1: if len(channel_types) == 1:
name = channel_types[0].__name__ display_name = channel_types[0].__name__
types = CHANNEL_TO_TYPES[channel_types[0]] types = CHANNEL_TO_TYPES[channel_types[0]]
else: else:
name = 'MultiChannel' display_name = '{}, and {}'.format(', '.join(t.__name__ for t in channel_types[:-1]), channel_types[-1].__name__)
types = [] types = []
for t in channel_types: for t in channel_types:
@ -570,15 +586,45 @@ def channel_transformer(*channel_types: Type[Any], raw: Optional[bool] = False)
except KeyError: except KeyError:
raise TypeError(f'Union type of channels must be entirely made up of channels') from None raise TypeError(f'Union type of channels must be entirely made up of channels') from None
return type( self._types: Tuple[Type[Any]] = channel_types
f'{name}Transformer', self._channel_types: List[ChannelType] = types
(Transformer,), self._display_name = display_name
{
'type': classmethod(lambda cls: AppCommandOptionType.channel), @property
'transform': classmethod(transform), def _error_display_name(self) -> str:
'channel_types': classmethod(lambda cls: types), return self._display_name
},
) @property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.channel
@property
def channel_types(self) -> List[ChannelType]:
return self._channel_types
async def transform(self, interaction: Interaction, value: Any):
resolved = value.resolve()
if resolved is None or not isinstance(resolved, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
class RawChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any):
if not isinstance(value, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return value
class UnionChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any):
if isinstance(value, self._types):
return value
resolved = value.resolve()
if resolved is None or not isinstance(resolved, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = { CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
@ -604,23 +650,23 @@ CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
CategoryChannel: [ChannelType.category], CategoryChannel: [ChannelType.category],
} }
BUILT_IN_TRANSFORMERS: Dict[Any, Type[Transformer]] = { BUILT_IN_TRANSFORMERS: Dict[Any, Transformer] = {
str: passthrough_transformer(AppCommandOptionType.string), str: IdentityTransformer(AppCommandOptionType.string),
int: passthrough_transformer(AppCommandOptionType.integer), int: IdentityTransformer(AppCommandOptionType.integer),
float: passthrough_transformer(AppCommandOptionType.number), float: IdentityTransformer(AppCommandOptionType.number),
bool: passthrough_transformer(AppCommandOptionType.boolean), bool: IdentityTransformer(AppCommandOptionType.boolean),
User: passthrough_transformer(AppCommandOptionType.user), User: IdentityTransformer(AppCommandOptionType.user),
Member: MemberTransformer, Member: MemberTransformer(),
Role: passthrough_transformer(AppCommandOptionType.role), Role: IdentityTransformer(AppCommandOptionType.role),
AppCommandChannel: channel_transformer(AppCommandChannel, raw=True), AppCommandChannel: RawChannelTransformer(AppCommandChannel),
AppCommandThread: channel_transformer(AppCommandThread, raw=True), AppCommandThread: RawChannelTransformer(AppCommandThread),
GuildChannel: channel_transformer(GuildChannel), GuildChannel: BaseChannelTransformer(GuildChannel),
Thread: channel_transformer(Thread), Thread: BaseChannelTransformer(Thread),
StageChannel: channel_transformer(StageChannel), StageChannel: BaseChannelTransformer(StageChannel),
VoiceChannel: channel_transformer(VoiceChannel), VoiceChannel: BaseChannelTransformer(VoiceChannel),
TextChannel: channel_transformer(TextChannel), TextChannel: BaseChannelTransformer(TextChannel),
CategoryChannel: channel_transformer(CategoryChannel), CategoryChannel: BaseChannelTransformer(CategoryChannel),
Attachment: passthrough_transformer(AppCommandOptionType.attachment), Attachment: IdentityTransformer(AppCommandOptionType.attachment),
} }
ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = { ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
@ -635,7 +681,7 @@ def get_supported_annotation(
annotation: Any, annotation: Any,
*, *,
_none: type = NoneType, _none: type = NoneType,
_mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS, _mapping: Dict[Any, Transformer] = BUILT_IN_TRANSFORMERS,
) -> Tuple[Any, Any, bool]: ) -> Tuple[Any, Any, bool]:
"""Returns an appropriate, yet supported, annotation along with an optional default value. """Returns an appropriate, yet supported, annotation along with an optional default value.
@ -650,20 +696,20 @@ def get_supported_annotation(
except KeyError: except KeyError:
pass pass
if hasattr(annotation, '__discord_app_commands_transform__'): if isinstance(annotation, Transformer):
return (annotation.metadata, MISSING, False) return (annotation, MISSING, False)
if hasattr(annotation, '__metadata__'): if hasattr(annotation, '__metadata__'):
return get_supported_annotation(annotation.__metadata__[0]) return get_supported_annotation(annotation.__metadata__[0])
if inspect.isclass(annotation): if inspect.isclass(annotation):
if issubclass(annotation, Transformer): if issubclass(annotation, Transformer):
return (annotation, MISSING, False) return (annotation(), MISSING, False)
if issubclass(annotation, (Enum, InternalEnum)): if issubclass(annotation, (Enum, InternalEnum)):
if all(isinstance(v.value, (str, int, float)) for v in annotation): if all(isinstance(v.value, (str, int, float)) for v in annotation):
return (_make_enum_transformer(annotation), MISSING, False) return (EnumValueTransformer(annotation), MISSING, False)
else: else:
return (_make_complex_enum_transformer(annotation), MISSING, False) return (EnumNameTransformer(annotation), MISSING, False)
if annotation is Choice: if annotation is Choice:
raise TypeError(f'Choice requires a type argument of int, str, or float') raise TypeError(f'Choice requires a type argument of int, str, or float')
@ -671,11 +717,11 @@ def get_supported_annotation(
origin = getattr(annotation, '__origin__', None) origin = getattr(annotation, '__origin__', None)
if origin is Literal: if origin is Literal:
args = annotation.__args__ # type: ignore args = annotation.__args__ # type: ignore
return (_make_literal_transformer(args), MISSING, True) return (LiteralTransformer(args), MISSING, True)
if origin is Choice: if origin is Choice:
arg = annotation.__args__[0] # type: ignore arg = annotation.__args__[0] # type: ignore
return (_make_choice_transformer(arg), MISSING, True) return (ChoiceTransformer(arg), MISSING, True)
if origin is not Union: if origin is not Union:
# Only Union/Optional is supported right now so bail early # Only Union/Optional is supported right now so bail early
@ -697,7 +743,7 @@ def get_supported_annotation(
# Check for channel union types # Check for channel union types
if any(arg in CHANNEL_TO_TYPES for arg in args): if any(arg in CHANNEL_TO_TYPES for arg in args):
# If any channel type is given, then *all* must be channel types # If any channel type is given, then *all* must be channel types
return (channel_transformer(*args, raw=None), default, True) return (UnionChannelTransformer(*args), default, True)
# The only valid transformations here are: # The only valid transformations here are:
# [Member, User] => user # [Member, User] => user
@ -707,9 +753,9 @@ def get_supported_annotation(
if not all(arg in supported_types for arg in args): if not all(arg in supported_types for arg in args):
raise TypeError(f'unsupported types given inside {annotation!r}') raise TypeError(f'unsupported types given inside {annotation!r}')
if args == (User, Member) or args == (Member, User): if args == (User, Member) or args == (Member, User):
return (passthrough_transformer(AppCommandOptionType.user), default, True) return (IdentityTransformer(AppCommandOptionType.user), default, True)
return (passthrough_transformer(AppCommandOptionType.mentionable), default, True) return (IdentityTransformer(AppCommandOptionType.mentionable), default, True)
def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter: def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter:
@ -721,7 +767,7 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
""" """
(inner, default, validate_default) = get_supported_annotation(annotation) (inner, default, validate_default) = get_supported_annotation(annotation)
type = inner.type() type = inner.type
if default is MISSING or default is None: if default is MISSING or default is None:
param_default = parameter.default param_default = parameter.default
@ -742,26 +788,23 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
name=parameter.name, name=parameter.name,
) )
try: choices = inner.choices
choices = inner.__discord_app_commands_transformer_choices__ if choices is not None:
except AttributeError:
pass
else:
result.choices = choices result.choices = choices
# These methods should be duck typed # These methods should be duck typed
if type in (AppCommandOptionType.number, AppCommandOptionType.string, AppCommandOptionType.integer): if type in (AppCommandOptionType.number, AppCommandOptionType.string, AppCommandOptionType.integer):
result.min_value = inner.min_value() result.min_value = inner.min_value
result.max_value = inner.max_value() result.max_value = inner.max_value
if type is AppCommandOptionType.channel: if type is AppCommandOptionType.channel:
result.channel_types = inner.channel_types() result.channel_types = inner.channel_types
if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL): if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL):
raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}') raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}')
autocomplete_func = getattr(inner.autocomplete, '__func__', inner.autocomplete) # Check if the method is overridden
if autocomplete_func is not Transformer.autocomplete.__func__: if inner.autocomplete.__func__ is not Transformer.autocomplete:
from .commands import _validate_auto_complete_callback from .commands import _validate_auto_complete_callback
result.autocomplete = _validate_auto_complete_callback(inner.autocomplete, skip_binding=True) result.autocomplete = _validate_auto_complete_callback(inner.autocomplete, skip_binding=True)

3
discord/ext/commands/core.py

@ -165,9 +165,6 @@ def get_signature_parameters(
if len(metadata) >= 1: if len(metadata) >= 1:
annotation = metadata[0] annotation = metadata[0]
if isinstance(annotation, discord.app_commands.transformers._TransformMetadata):
annotation = annotation.metadata
params[name] = parameter.replace(annotation=annotation) params[name] = parameter.replace(annotation=annotation)
return params return params

49
discord/ext/commands/hybrid.py

@ -116,18 +116,24 @@ def required_pos_arguments(func: Callable[..., Any]) -> int:
return sum(p.default is p.empty for p in sig.parameters.values()) return sum(p.default is p.empty for p in sig.parameters.values())
def make_converter_transformer(converter: Any, parameter: Parameter) -> Type[app_commands.Transformer]: class ConverterTransformer(app_commands.Transformer):
def __init__(self, converter: Any, parameter: Parameter) -> None:
super().__init__()
self.converter: Any = converter
self.parameter: Parameter = parameter
try: try:
module = converter.__module__ module = converter.__module__
except AttributeError: except AttributeError:
pass pass
else: else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')): if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
converter = CONVERTER_MAPPING.get(converter, converter) self.converter = CONVERTER_MAPPING.get(converter, converter)
async def transform(cls, interaction: discord.Interaction, value: str) -> Any: async def transform(self, interaction: discord.Interaction, value: str) -> Any:
ctx = interaction._baton ctx = interaction._baton
ctx.current_parameter = parameter converter = self.converter
ctx.current_parameter = self.parameter
ctx.current_argument = value ctx.current_argument = value
try: try:
if inspect.isclass(converter) and issubclass(converter, Converter): if inspect.isclass(converter) and issubclass(converter, Converter):
@ -142,27 +148,34 @@ def make_converter_transformer(converter: Any, parameter: Parameter) -> Type[app
except Exception as exc: except Exception as exc:
raise ConversionError(converter, exc) from exc # type: ignore raise ConversionError(converter, exc) from exc # type: ignore
return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
class CallableTransformer(app_commands.Transformer):
def __init__(self, func: Callable[[str], Any]) -> None:
super().__init__()
self.func: Callable[[str], Any] = func
def make_callable_transformer(func: Callable[[str], Any]) -> Type[app_commands.Transformer]: async def transform(self, interaction: discord.Interaction, value: str) -> Any:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
try: try:
return func(value) return self.func(value)
except CommandError: except CommandError:
raise raise
except Exception as exc: except Exception as exc:
raise BadArgument(f'Converting to "{func.__name__}" failed') from exc raise BadArgument(f'Converting to "{self.func.__name__}" failed') from exc
return type('CallableTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
class GreedyTransformer(app_commands.Transformer):
def __init__(self, converter: Any, parameter: Parameter) -> None:
super().__init__()
self.converter: Any = converter
self.parameter: Parameter = parameter
def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_commands.Transformer]: async def transform(self, interaction: discord.Interaction, value: str) -> Any:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
view = StringView(value) view = StringView(value)
result = [] result = []
ctx = interaction._baton ctx = interaction._baton
ctx.current_parameter = parameter ctx.current_parameter = parameter = self.parameter
converter = self.converter
while True: while True:
view.skip_ws() view.skip_ws()
ctx.current_argument = arg = view.get_quoted_word() ctx.current_argument = arg = view.get_quoted_word()
@ -175,8 +188,6 @@ def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_co
return result return result
return type('GreedyTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def replace_parameter( def replace_parameter(
param: inspect.Parameter, param: inspect.Parameter,
@ -203,7 +214,7 @@ def replace_parameter(
if inner is discord.Attachment: if inner is discord.Attachment:
raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands') raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands')
param = param.replace(annotation=make_greedy_transformer(inner, original)) param = param.replace(annotation=GreedyTransformer(inner, original))
elif is_flag(converter): elif is_flag(converter):
callback.__hybrid_command_flag__ = (param.name, converter) callback.__hybrid_command_flag__ = (param.name, converter)
descriptions = {} descriptions = {}
@ -233,14 +244,14 @@ def replace_parameter(
app_commands.rename(**renames)(callback) app_commands.rename(**renames)(callback)
elif is_converter(converter) or converter in CONVERTER_MAPPING: elif is_converter(converter) or converter in CONVERTER_MAPPING:
param = param.replace(annotation=make_converter_transformer(converter, original)) param = param.replace(annotation=ConverterTransformer(converter, original))
elif origin is Union: elif origin is Union:
if len(args) == 2 and args[-1] is _NoneType: if len(args) == 2 and args[-1] is _NoneType:
# Special case Optional[X] where X is a single type that can optionally be a converter # Special case Optional[X] where X is a single type that can optionally be a converter
inner = args[0] inner = args[0]
is_inner_transformer = is_transformer(inner) is_inner_transformer = is_transformer(inner)
if is_converter(inner) and not is_inner_transformer: if is_converter(inner) and not is_inner_transformer:
param = param.replace(annotation=Optional[make_converter_transformer(inner, original)]) # type: ignore param = param.replace(annotation=Optional[ConverterTransformer(inner, original)]) # type: ignore
else: else:
raise raise
elif origin: elif origin:
@ -250,7 +261,7 @@ def replace_parameter(
param_count = required_pos_arguments(converter) param_count = required_pos_arguments(converter)
if param_count != 1: if param_count != 1:
raise raise
param = param.replace(annotation=make_callable_transformer(converter)) param = param.replace(annotation=CallableTransformer(converter))
return param return param

Loading…
Cancel
Save