Browse Source

Refactor transformers to use instances instead of classmethods

This should allow them to be easier to use for users without requiring
a lot of metaprogramming hackery if you want to involve state.
pull/8280/head
Rapptz 3 years ago
parent
commit
11618cd1ba
  1. 20
      discord/app_commands/errors.py
  2. 503
      discord/app_commands/transformers.py
  3. 3
      discord/ext/commands/core.py
  4. 63
      discord/ext/commands/hybrid.py

20
discord/app_commands/errors.py

@ -121,28 +121,16 @@ class TransformerError(AppCommandError):
The value that failed to convert.
type: :class:`~discord.AppCommandOptionType`
The type of argument that failed to convert.
transformer: Type[:class:`Transformer`]
transformer: :class:`Transformer`
The transformer that failed the conversion.
"""
def __init__(self, value: Any, opt_type: AppCommandOptionType, transformer: Type[Transformer]):
def __init__(self, value: Any, opt_type: AppCommandOptionType, transformer: Transformer):
self.value: Any = value
self.type: AppCommandOptionType = opt_type
self.transformer: Type[Transformer] = transformer
try:
result_type = transformer.transform.__annotations__['return']
except KeyError:
name = transformer.__name__
if name.endswith('Transformer'):
result_type = name[:-11]
else:
result_type = name
else:
if isinstance(result_type, type):
result_type = result_type.__name__
self.transformer: Transformer = transformer
super().__init__(f'Failed to convert {value} to {result_type!s}')
super().__init__(f'Failed to convert {value} to {transformer._error_display_name!s}')
class CheckFailure(AppCommandError):

503
discord/app_commands/transformers.py

@ -166,7 +166,7 @@ class Transformer:
to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one
from this type.
This class is customisable through the overriding of :func:`classmethod` in the class
This class is customisable through the overriding of methods and properties in the class
and by using it as the second type parameter of the :class:`~discord.app_commands.Transform`
class. For example, to convert a string into a custom pair type:
@ -177,8 +177,7 @@ class Transformer:
y: int
class PointTransformer(app_commands.Transformer):
@classmethod
async def transform(cls, interaction: discord.Interaction, value: str) -> Point:
async def transform(self, interaction: discord.Interaction, value: str) -> Point:
(x, _, y) = value.partition(',')
return Point(x=int(x.strip()), y=int(y.strip()))
@ -189,56 +188,90 @@ class Transformer:
):
await interaction.response.send_message(str(point))
If a class is passed instead of an instance to the second type parameter, then it is
constructed with no arguments passed to the ``__init__`` method.
.. versionadded:: 2.0
"""
__discord_app_commands_transformer__: ClassVar[bool] = True
__discord_app_commands_is_choice__: ClassVar[bool] = False
@classmethod
def type(cls) -> AppCommandOptionType:
# This is needed to pass typing's type checks.
# e.g. Optional[MyTransformer]
def __call__(self) -> None:
pass
@property
def type(self) -> AppCommandOptionType:
""":class:`~discord.AppCommandOptionType`: The option type associated with this transformer.
This must be a :obj:`classmethod`.
This must be a :obj:`property`.
Defaults to :attr:`~discord.AppCommandOptionType.string`.
"""
return AppCommandOptionType.string
@classmethod
def channel_types(cls) -> List[ChannelType]:
@property
def channel_types(self) -> List[ChannelType]:
"""List[:class:`~discord.ChannelType`]: A list of channel types that are allowed to this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.channel`.
This must be a :obj:`property`.
Defaults to an empty list.
"""
return []
@classmethod
def min_value(cls) -> Optional[Union[int, float]]:
@property
def min_value(self) -> Optional[Union[int, float]]:
"""Optional[:class:`int`]: The minimum supported value for this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``.
"""
return None
@classmethod
def max_value(cls) -> Optional[Union[int, float]]:
@property
def max_value(self) -> Optional[Union[int, float]]:
"""Optional[:class:`int`]: The maximum supported value for this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``.
"""
return None
@classmethod
async def transform(cls, interaction: Interaction, value: Any) -> Any:
@property
def choices(self) -> Optional[List[Choice[Union[int, float, str]]]]:
"""Optional[List[:class:`~discord.app_commands.Choice`]]: A list of choices that are allowed to this parameter.
Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number`
:attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`.
This must be a :obj:`property`.
Defaults to ``None``.
"""
return None
@property
def _error_display_name(self) -> str:
name = self.__class__.__name__
if name.endswith('Transformer'):
return name[:-11]
else:
return name
async def transform(self, interaction: Interaction, value: Any) -> Any:
"""|maybecoro|
Transforms the converted option value into another value.
@ -257,9 +290,8 @@ class Transformer:
"""
raise NotImplementedError('Derived classes need to implement this.')
@classmethod
async def autocomplete(
cls, interaction: Interaction, value: Union[int, float, str]
self, interaction: Interaction, value: Union[int, float, str]
) -> List[Choice[Union[int, float, str]]]:
"""|coro|
@ -287,122 +319,141 @@ class Transformer:
raise NotImplementedError('Derived classes can implement this.')
class _TransformMetadata:
__discord_app_commands_transform__: ClassVar[bool] = True
__slots__ = ('metadata',)
class IdentityTransformer(Transformer):
def __init__(self, type: AppCommandOptionType) -> None:
self._type = type
def __init__(self, metadata: Type[Transformer]):
self.metadata: Type[Transformer] = metadata
@property
def type(self) -> AppCommandOptionType:
return self._type
# This is needed to pass typing's type checks.
# e.g. Optional[Transform[discord.Member, MyTransformer]]
def __call__(self) -> None:
pass
async def transform(self, interaction: Interaction, value: Any) -> Any:
return value
async def _identity_transform(cls, interaction: Interaction, value: Any) -> Any:
return value
class RangeTransformer(IdentityTransformer):
def __init__(
self,
opt_type: AppCommandOptionType,
*,
min: Optional[Union[int, float]] = None,
max: Optional[Union[int, float]] = None,
) -> None:
if min and max and min > max:
raise TypeError('minimum cannot be larger than maximum')
self._min: Optional[Union[int, float]] = min
self._max: Optional[Union[int, float]] = max
super().__init__(opt_type)
@property
def min_value(self) -> Optional[Union[int, float]]:
return self._min
@property
def max_value(self) -> Optional[Union[int, float]]:
return self._max
class LiteralTransformer(IdentityTransformer):
def __init__(self, values: Tuple[Any, ...]) -> None:
first = type(values[0])
if first is int:
opt_type = AppCommandOptionType.integer
elif first is float:
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
self._choices = [Choice(name=str(v), value=v) for v in values]
super().__init__(opt_type)
@property
def choices(self):
return self._choices
class ChoiceTransformer(IdentityTransformer):
__discord_app_commands_is_choice__: ClassVar[bool] = True
def __init__(self, inner_type: Any) -> None:
if inner_type is int:
opt_type = AppCommandOptionType.integer
elif inner_type is float:
opt_type = AppCommandOptionType.number
elif inner_type is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {inner_type!r}')
super().__init__(opt_type)
def _make_range_transformer(
opt_type: AppCommandOptionType,
*,
min: Optional[Union[int, float]] = None,
max: Optional[Union[int, float]] = None,
) -> Type[Transformer]:
if min and max and min > max:
raise TypeError('minimum cannot be larger than maximum')
ns = {
'type': classmethod(lambda _: opt_type),
'min_value': classmethod(lambda _: min),
'max_value': classmethod(lambda _: max),
'transform': classmethod(_identity_transform),
}
return type('RangeTransformer', (Transformer,), ns)
def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]:
first = type(values[0])
if first is int:
opt_type = AppCommandOptionType.integer
elif first is float:
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
ns = {
'type': classmethod(lambda _: opt_type),
'transform': classmethod(_identity_transform),
'__discord_app_commands_transformer_choices__': [Choice(name=str(v), value=v) for v in values],
}
return type('LiteralTransformer', (Transformer,), ns)
def _make_choice_transformer(inner_type: Any) -> Type[Transformer]:
if inner_type is int:
opt_type = AppCommandOptionType.integer
elif inner_type is float:
opt_type = AppCommandOptionType.number
elif inner_type is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {inner_type!r}')
ns = {
'type': classmethod(lambda _: opt_type),
'transform': classmethod(_identity_transform),
'__discord_app_commands_is_choice__': True,
}
return type('ChoiceTransformer', (Transformer,), ns)
def _make_enum_transformer(enum) -> Type[Transformer]:
values = list(enum)
if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.')
first = type(values[0].value)
if first is int:
opt_type = AppCommandOptionType.integer
elif first is float:
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return enum(value)
ns = {
'type': classmethod(lambda _: opt_type),
'transform': classmethod(transform),
'__discord_app_commands_transformer_enum__': enum,
'__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.value) for v in values],
}
return type(f'{enum.__name__}EnumTransformer', (Transformer,), ns)
def _make_complex_enum_transformer(enum) -> Type[Transformer]:
values = list(enum)
if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.')
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return enum[value]
ns = {
'type': classmethod(lambda _: AppCommandOptionType.string),
'transform': classmethod(transform),
'__discord_app_commands_transformer_enum__': enum,
'__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.name) for v in values],
}
return type(f'{enum.__name__}ComplexEnumTransformer', (Transformer,), ns)
class EnumValueTransformer(Transformer):
def __init__(self, enum: Any) -> None:
super().__init__()
values = list(enum)
if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.')
first = type(values[0].value)
if first is int:
opt_type = AppCommandOptionType.integer
elif first is float:
opt_type = AppCommandOptionType.number
elif first is str:
opt_type = AppCommandOptionType.string
else:
raise TypeError(f'expected int, str, or float values not {first!r}')
self._type: AppCommandOptionType = opt_type
self._enum: Any = enum
self._choices = [Choice(name=v.name, value=v.value) for v in values]
@property
def _error_display_name(self) -> str:
return self._enum.__name__
@property
def type(self) -> AppCommandOptionType:
return self._type
@property
def choices(self):
return self._choices
async def transform(self, interaction: Interaction, value: Any) -> Any:
return self._enum(value)
class EnumNameTransformer(Transformer):
def __init__(self, enum: Any) -> None:
super().__init__()
values = list(enum)
if len(values) < 2:
raise TypeError(f'enum.Enum requires at least two values.')
self._enum: Any = enum
self._choices = [Choice(name=v.name, value=v.value) for v in values]
@property
def _error_display_name(self) -> str:
return self._enum.__name__
@property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.string
@property
def choices(self):
return self._choices
async def transform(self, interaction: Interaction, value: Any) -> Any:
return self._enum[value]
if TYPE_CHECKING:
@ -433,11 +484,14 @@ else:
_, transformer = items
is_valid = inspect.isclass(transformer) and issubclass(transformer, Transformer)
if not is_valid:
raise TypeError(f'second argument of Transform must be a Transformer class not {transformer!r}')
if inspect.isclass(transformer):
if not issubclass(transformer, Transformer):
raise TypeError(f'second argument of Transform must be a Transformer class not {transformer!r}')
transformer = transformer()
elif not isinstance(transformer, Transformer):
raise TypeError(f'second argument of Transform must be a Transformer not {transformer.__class__!r}')
return _TransformMetadata(transformer)
return transformer
class Range:
"""A type annotation that can be applied to a parameter to require a numeric or string
@ -497,88 +551,80 @@ else:
else:
cast = float
transformer = _make_range_transformer(
transformer = RangeTransformer(
opt_type,
min=cast(min) if min is not None else None,
max=cast(max) if max is not None else None,
)
return _TransformMetadata(transformer)
def passthrough_transformer(opt_type: AppCommandOptionType) -> Type[Transformer]:
class _Generated(Transformer):
@classmethod
def type(cls) -> AppCommandOptionType:
return opt_type
@classmethod
async def transform(cls, interaction: Interaction, value: Any) -> Any:
return value
return _Generated
return transformer
class MemberTransformer(Transformer):
@classmethod
def type(cls) -> AppCommandOptionType:
@property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.user
@classmethod
async def transform(cls, interaction: Interaction, value: Any) -> Member:
async def transform(self, interaction: Interaction, value: Any) -> Member:
if not isinstance(value, Member):
raise TransformerError(value, cls.type(), cls)
raise TransformerError(value, self.type, self)
return value
def channel_transformer(*channel_types: Type[Any], raw: Optional[bool] = False) -> Type[Transformer]:
if raw:
class BaseChannelTransformer(Transformer):
def __init__(self, *channel_types: Type[Any]) -> None:
super().__init__()
if len(channel_types) == 1:
display_name = channel_types[0].__name__
types = CHANNEL_TO_TYPES[channel_types[0]]
else:
display_name = '{}, and {}'.format(', '.join(t.__name__ for t in channel_types[:-1]), channel_types[-1].__name__)
types = []
async def transform(cls, interaction: Interaction, value: Any):
if not isinstance(value, channel_types):
raise TransformerError(value, AppCommandOptionType.channel, cls)
return value
for t in channel_types:
try:
types.extend(CHANNEL_TO_TYPES[t])
except KeyError:
raise TypeError(f'Union type of channels must be entirely made up of channels') from None
elif raw is False:
self._types: Tuple[Type[Any]] = channel_types
self._channel_types: List[ChannelType] = types
self._display_name = display_name
async def transform(cls, interaction: Interaction, value: Any):
resolved = value.resolve()
if resolved is None or not isinstance(resolved, channel_types):
raise TransformerError(value, AppCommandOptionType.channel, cls)
return resolved
@property
def _error_display_name(self) -> str:
return self._display_name
else:
@property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.channel
async def transform(cls, interaction: Interaction, value: Any):
if isinstance(value, channel_types):
return value
@property
def channel_types(self) -> List[ChannelType]:
return self._channel_types
resolved = value.resolve()
if resolved is None or not isinstance(resolved, channel_types):
raise TransformerError(value, AppCommandOptionType.channel, cls)
return resolved
async def transform(self, interaction: Interaction, value: Any):
resolved = value.resolve()
if resolved is None or not isinstance(resolved, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
if len(channel_types) == 1:
name = channel_types[0].__name__
types = CHANNEL_TO_TYPES[channel_types[0]]
else:
name = 'MultiChannel'
types = []
for t in channel_types:
try:
types.extend(CHANNEL_TO_TYPES[t])
except KeyError:
raise TypeError(f'Union type of channels must be entirely made up of channels') from None
return type(
f'{name}Transformer',
(Transformer,),
{
'type': classmethod(lambda cls: AppCommandOptionType.channel),
'transform': classmethod(transform),
'channel_types': classmethod(lambda cls: types),
},
)
class RawChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any):
if not isinstance(value, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return value
class UnionChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any):
if isinstance(value, self._types):
return value
resolved = value.resolve()
if resolved is None or not isinstance(resolved, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
@ -604,23 +650,23 @@ CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
CategoryChannel: [ChannelType.category],
}
BUILT_IN_TRANSFORMERS: Dict[Any, Type[Transformer]] = {
str: passthrough_transformer(AppCommandOptionType.string),
int: passthrough_transformer(AppCommandOptionType.integer),
float: passthrough_transformer(AppCommandOptionType.number),
bool: passthrough_transformer(AppCommandOptionType.boolean),
User: passthrough_transformer(AppCommandOptionType.user),
Member: MemberTransformer,
Role: passthrough_transformer(AppCommandOptionType.role),
AppCommandChannel: channel_transformer(AppCommandChannel, raw=True),
AppCommandThread: channel_transformer(AppCommandThread, raw=True),
GuildChannel: channel_transformer(GuildChannel),
Thread: channel_transformer(Thread),
StageChannel: channel_transformer(StageChannel),
VoiceChannel: channel_transformer(VoiceChannel),
TextChannel: channel_transformer(TextChannel),
CategoryChannel: channel_transformer(CategoryChannel),
Attachment: passthrough_transformer(AppCommandOptionType.attachment),
BUILT_IN_TRANSFORMERS: Dict[Any, Transformer] = {
str: IdentityTransformer(AppCommandOptionType.string),
int: IdentityTransformer(AppCommandOptionType.integer),
float: IdentityTransformer(AppCommandOptionType.number),
bool: IdentityTransformer(AppCommandOptionType.boolean),
User: IdentityTransformer(AppCommandOptionType.user),
Member: MemberTransformer(),
Role: IdentityTransformer(AppCommandOptionType.role),
AppCommandChannel: RawChannelTransformer(AppCommandChannel),
AppCommandThread: RawChannelTransformer(AppCommandThread),
GuildChannel: BaseChannelTransformer(GuildChannel),
Thread: BaseChannelTransformer(Thread),
StageChannel: BaseChannelTransformer(StageChannel),
VoiceChannel: BaseChannelTransformer(VoiceChannel),
TextChannel: BaseChannelTransformer(TextChannel),
CategoryChannel: BaseChannelTransformer(CategoryChannel),
Attachment: IdentityTransformer(AppCommandOptionType.attachment),
}
ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
@ -635,7 +681,7 @@ def get_supported_annotation(
annotation: Any,
*,
_none: type = NoneType,
_mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS,
_mapping: Dict[Any, Transformer] = BUILT_IN_TRANSFORMERS,
) -> Tuple[Any, Any, bool]:
"""Returns an appropriate, yet supported, annotation along with an optional default value.
@ -650,20 +696,20 @@ def get_supported_annotation(
except KeyError:
pass
if hasattr(annotation, '__discord_app_commands_transform__'):
return (annotation.metadata, MISSING, False)
if isinstance(annotation, Transformer):
return (annotation, MISSING, False)
if hasattr(annotation, '__metadata__'):
return get_supported_annotation(annotation.__metadata__[0])
if inspect.isclass(annotation):
if issubclass(annotation, Transformer):
return (annotation, MISSING, False)
return (annotation(), MISSING, False)
if issubclass(annotation, (Enum, InternalEnum)):
if all(isinstance(v.value, (str, int, float)) for v in annotation):
return (_make_enum_transformer(annotation), MISSING, False)
return (EnumValueTransformer(annotation), MISSING, False)
else:
return (_make_complex_enum_transformer(annotation), MISSING, False)
return (EnumNameTransformer(annotation), MISSING, False)
if annotation is Choice:
raise TypeError(f'Choice requires a type argument of int, str, or float')
@ -671,11 +717,11 @@ def get_supported_annotation(
origin = getattr(annotation, '__origin__', None)
if origin is Literal:
args = annotation.__args__ # type: ignore
return (_make_literal_transformer(args), MISSING, True)
return (LiteralTransformer(args), MISSING, True)
if origin is Choice:
arg = annotation.__args__[0] # type: ignore
return (_make_choice_transformer(arg), MISSING, True)
return (ChoiceTransformer(arg), MISSING, True)
if origin is not Union:
# Only Union/Optional is supported right now so bail early
@ -697,7 +743,7 @@ def get_supported_annotation(
# Check for channel union types
if any(arg in CHANNEL_TO_TYPES for arg in args):
# If any channel type is given, then *all* must be channel types
return (channel_transformer(*args, raw=None), default, True)
return (UnionChannelTransformer(*args), default, True)
# The only valid transformations here are:
# [Member, User] => user
@ -707,9 +753,9 @@ def get_supported_annotation(
if not all(arg in supported_types for arg in args):
raise TypeError(f'unsupported types given inside {annotation!r}')
if args == (User, Member) or args == (Member, User):
return (passthrough_transformer(AppCommandOptionType.user), default, True)
return (IdentityTransformer(AppCommandOptionType.user), default, True)
return (passthrough_transformer(AppCommandOptionType.mentionable), default, True)
return (IdentityTransformer(AppCommandOptionType.mentionable), default, True)
def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter:
@ -721,7 +767,7 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
"""
(inner, default, validate_default) = get_supported_annotation(annotation)
type = inner.type()
type = inner.type
if default is MISSING or default is None:
param_default = parameter.default
@ -742,26 +788,23 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
name=parameter.name,
)
try:
choices = inner.__discord_app_commands_transformer_choices__
except AttributeError:
pass
else:
choices = inner.choices
if choices is not None:
result.choices = choices
# These methods should be duck typed
if type in (AppCommandOptionType.number, AppCommandOptionType.string, AppCommandOptionType.integer):
result.min_value = inner.min_value()
result.max_value = inner.max_value()
result.min_value = inner.min_value
result.max_value = inner.max_value
if type is AppCommandOptionType.channel:
result.channel_types = inner.channel_types()
result.channel_types = inner.channel_types
if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL):
raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}')
autocomplete_func = getattr(inner.autocomplete, '__func__', inner.autocomplete)
if autocomplete_func is not Transformer.autocomplete.__func__:
# Check if the method is overridden
if inner.autocomplete.__func__ is not Transformer.autocomplete:
from .commands import _validate_auto_complete_callback
result.autocomplete = _validate_auto_complete_callback(inner.autocomplete, skip_binding=True)

3
discord/ext/commands/core.py

@ -165,9 +165,6 @@ def get_signature_parameters(
if len(metadata) >= 1:
annotation = metadata[0]
if isinstance(annotation, discord.app_commands.transformers._TransformMetadata):
annotation = annotation.metadata
params[name] = parameter.replace(annotation=annotation)
return params

63
discord/ext/commands/hybrid.py

@ -116,18 +116,24 @@ def required_pos_arguments(func: Callable[..., Any]) -> int:
return sum(p.default is p.empty for p in sig.parameters.values())
def make_converter_transformer(converter: Any, parameter: Parameter) -> Type[app_commands.Transformer]:
try:
module = converter.__module__
except AttributeError:
pass
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
converter = CONVERTER_MAPPING.get(converter, converter)
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
class ConverterTransformer(app_commands.Transformer):
def __init__(self, converter: Any, parameter: Parameter) -> None:
super().__init__()
self.converter: Any = converter
self.parameter: Parameter = parameter
try:
module = converter.__module__
except AttributeError:
pass
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
self.converter = CONVERTER_MAPPING.get(converter, converter)
async def transform(self, interaction: discord.Interaction, value: str) -> Any:
ctx = interaction._baton
ctx.current_parameter = parameter
converter = self.converter
ctx.current_parameter = self.parameter
ctx.current_argument = value
try:
if inspect.isclass(converter) and issubclass(converter, Converter):
@ -142,27 +148,34 @@ def make_converter_transformer(converter: Any, parameter: Parameter) -> Type[app
except Exception as exc:
raise ConversionError(converter, exc) from exc # type: ignore
return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
class CallableTransformer(app_commands.Transformer):
def __init__(self, func: Callable[[str], Any]) -> None:
super().__init__()
self.func: Callable[[str], Any] = func
def make_callable_transformer(func: Callable[[str], Any]) -> Type[app_commands.Transformer]:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
async def transform(self, interaction: discord.Interaction, value: str) -> Any:
try:
return func(value)
return self.func(value)
except CommandError:
raise
except Exception as exc:
raise BadArgument(f'Converting to "{func.__name__}" failed') from exc
raise BadArgument(f'Converting to "{self.func.__name__}" failed') from exc
return type('CallableTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
class GreedyTransformer(app_commands.Transformer):
def __init__(self, converter: Any, parameter: Parameter) -> None:
super().__init__()
self.converter: Any = converter
self.parameter: Parameter = parameter
def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_commands.Transformer]:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
async def transform(self, interaction: discord.Interaction, value: str) -> Any:
view = StringView(value)
result = []
ctx = interaction._baton
ctx.current_parameter = parameter
ctx.current_parameter = parameter = self.parameter
converter = self.converter
while True:
view.skip_ws()
ctx.current_argument = arg = view.get_quoted_word()
@ -175,8 +188,6 @@ def make_greedy_transformer(converter: Any, parameter: Parameter) -> Type[app_co
return result
return type('GreedyTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def replace_parameter(
param: inspect.Parameter,
@ -203,7 +214,7 @@ def replace_parameter(
if inner is discord.Attachment:
raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands')
param = param.replace(annotation=make_greedy_transformer(inner, original))
param = param.replace(annotation=GreedyTransformer(inner, original))
elif is_flag(converter):
callback.__hybrid_command_flag__ = (param.name, converter)
descriptions = {}
@ -233,14 +244,14 @@ def replace_parameter(
app_commands.rename(**renames)(callback)
elif is_converter(converter) or converter in CONVERTER_MAPPING:
param = param.replace(annotation=make_converter_transformer(converter, original))
param = param.replace(annotation=ConverterTransformer(converter, original))
elif origin is Union:
if len(args) == 2 and args[-1] is _NoneType:
# Special case Optional[X] where X is a single type that can optionally be a converter
inner = args[0]
is_inner_transformer = is_transformer(inner)
if is_converter(inner) and not is_inner_transformer:
param = param.replace(annotation=Optional[make_converter_transformer(inner, original)]) # type: ignore
param = param.replace(annotation=Optional[ConverterTransformer(inner, original)]) # type: ignore
else:
raise
elif origin:
@ -250,7 +261,7 @@ def replace_parameter(
param_count = required_pos_arguments(converter)
if param_count != 1:
raise
param = param.replace(annotation=make_callable_transformer(converter))
param = param.replace(annotation=CallableTransformer(converter))
return param

Loading…
Cancel
Save