From c54e43360b027a976e0c6b1a768808f2da8384a0 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 19 Apr 2021 04:46:02 -0400 Subject: [PATCH] [commands] Add run_converters helper to call converters --- discord/ext/commands/converter.py | 177 +++++++++++++++++++++++++++++- discord/ext/commands/core.py | 125 ++------------------- docs/ext/commands/api.rst | 2 + 3 files changed, 188 insertions(+), 116 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index e94a08efb..b12b98049 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -26,7 +26,21 @@ from __future__ import annotations import re import inspect -from typing import Iterable, Optional, TYPE_CHECKING, List, Protocol, Type, TypeVar, Tuple, Union, runtime_checkable +from typing import ( + Any, + Dict, + Iterable, + Literal, + Optional, + TYPE_CHECKING, + List, + Protocol, + Type, + TypeVar, + Tuple, + Union, + runtime_checkable, +) import discord from .errors import * @@ -58,6 +72,7 @@ __all__ = ( 'StoreChannelConverter', 'clean_content', 'Greedy', + 'run_converters', ) @@ -867,3 +882,163 @@ class Greedy(List[T]): raise TypeError(f'Greedy[{converter!r}] is invalid.') return cls(converter=converter) + + +def _convert_to_bool(argument: str) -> bool: + lowered = argument.lower() + if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): + return True + elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): + return False + else: + raise BadBoolArgument(lowered) + + +def get_converter(param: inspect.Parameter) -> Any: + converter = param.annotation + if converter is param.empty: + if param.default is not param.empty: + converter = str if param.default is None else type(param.default) + else: + converter = str + return converter + + +CONVERTER_MAPPING: Dict[Type[Any], Any] = { + discord.Object: ObjectConverter, + discord.Member: MemberConverter, + discord.User: UserConverter, + discord.Message: MessageConverter, + discord.PartialMessage: PartialMessageConverter, + discord.TextChannel: TextChannelConverter, + discord.Invite: InviteConverter, + discord.Guild: GuildConverter, + discord.Role: RoleConverter, + discord.Game: GameConverter, + discord.Colour: ColourConverter, + discord.VoiceChannel: VoiceChannelConverter, + discord.StageChannel: StageChannelConverter, + discord.Emoji: EmojiConverter, + discord.PartialEmoji: PartialEmojiConverter, + discord.CategoryChannel: CategoryChannelConverter, + discord.StoreChannel: StoreChannelConverter, +} + + +async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): + if converter is bool: + return _convert_to_bool(argument) + + try: + module = converter.__module__ + except AttributeError: + pass + else: + if module is not None and (module.startswith('discord.') and not module.endswith('converter')): + converter = CONVERTER_MAPPING.get(converter, converter) + + try: + if inspect.isclass(converter) and issubclass(converter, Converter): + if inspect.ismethod(converter.convert): + return await converter.convert(ctx, argument) + else: + return await converter().convert(ctx, argument) + elif isinstance(converter, Converter): + return await converter.convert(ctx, argument) + except CommandError: + raise + except Exception as exc: + raise ConversionError(converter, exc) from exc + + try: + return converter(argument) + except CommandError: + raise + except Exception as exc: + try: + name = converter.__name__ + except AttributeError: + name = converter.__class__.__name__ + + raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc + + +async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): + """|coro| + + Runs converters for a given converter, argument, and parameter. + + This function does the same work that the library does under the hood. + + .. versionadded:: 2.0 + + Parameters + ------------ + ctx: :class:`Context` + The invocation context to run the converters under. + converter: Any + The converter to run, this corresponds to the annotation in the function. + argument: :class:`str` + The argument to convert to. + param: :class:`inspect.Parameter` + The parameter being converted. This is mainly for error reporting. + + Raises + ------- + CommandError + The converter failed to convert. + + Returns + -------- + Any + The resulting conversion. + """ + origin = getattr(converter, '__origin__', None) + + if origin is Union: + errors = [] + _NoneType = type(None) + union_args = converter.__args__ + for conv in union_args: + # if we got to this part in the code, then the previous conversions have failed + # so we should just undo the view, return the default, and allow parsing to continue + # with the other parameters + if conv is _NoneType and param.kind != param.VAR_POSITIONAL: + ctx.view.undo() + return None if param.default is param.empty else param.default + + try: + value = await run_converters(ctx, conv, argument, param) + except CommandError as exc: + errors.append(exc) + else: + return value + + # if we're here, then we failed all the converters + raise BadUnionArgument(param, union_args, errors) + + if origin is Literal: + errors = [] + conversions = {} + literal_args = converter.__args__ + for literal in literal_args: + literal_type = type(literal) + try: + value = conversions[literal_type] + except KeyError: + try: + value = await _actual_conversion(ctx, literal_type, argument, param) + except CommandError as exc: + errors.append(exc) + conversions[literal_type] = object() + continue + else: + conversions[literal_type] = value + + if value == literal: + return value + + # if we're here, then we failed to match all the literals + raise BadLiteralArgument(param, literal_args, errors) + + return await _actual_conversion(ctx, converter, argument, param) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index cbdd4fcfe..5809a156d 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -43,7 +43,7 @@ import discord from .errors import * from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping -from . import converter as converters +from .converter import run_converters, get_converter, Greedy from ._types import _BaseCommand from .cog import Cog @@ -175,7 +175,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect. continue annotation = _evaluate_annotation(annotation, globalns, globalns, cache) - if annotation is converters.Greedy: + if annotation is Greedy: raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') params[name] = parameter.replace(annotation=annotation) @@ -219,14 +219,6 @@ def hooked_wrapped_callback(command, ctx, coro): return ret return wrapped -def _convert_to_bool(argument): - lowered = argument.lower() - if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): - return True - elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): - return False - else: - raise BadBoolArgument(lowered) class _CaseInsensitiveDict(dict): def __contains__(self, k): @@ -541,113 +533,16 @@ class Command(_BaseCommand): finally: ctx.bot.dispatch('command_error', ctx, error) - async def _actual_conversion(self, ctx, converter, argument, param): - if converter is bool: - return _convert_to_bool(argument) - - try: - module = converter.__module__ - except AttributeError: - pass - else: - if module is not None and (module.startswith('discord.') and not module.endswith('converter')): - converter = getattr(converters, converter.__name__ + 'Converter', converter) - - try: - if inspect.isclass(converter) and issubclass(converter, converters.Converter): - if inspect.ismethod(converter.convert): - return await converter.convert(ctx, argument) - else: - return await converter().convert(ctx, argument) - elif isinstance(converter, converters.Converter): - return await converter.convert(ctx, argument) - except CommandError: - raise - except Exception as exc: - raise ConversionError(converter, exc) from exc - - try: - return converter(argument) - except CommandError: - raise - except Exception as exc: - try: - name = converter.__name__ - except AttributeError: - name = converter.__class__.__name__ - - raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc - - async def do_conversion(self, ctx, converter, argument, param): - origin = getattr(converter, '__origin__', None) - - if origin is Union: - errors = [] - _NoneType = type(None) - union_args = converter.__args__ - for conv in union_args: - # if we got to this part in the code, then the previous conversions have failed - # so we should just undo the view, return the default, and allow parsing to continue - # with the other parameters - if conv is _NoneType and param.kind != param.VAR_POSITIONAL: - ctx.view.undo() - return None if param.default is param.empty else param.default - - try: - value = await self.do_conversion(ctx, conv, argument, param) - except CommandError as exc: - errors.append(exc) - else: - return value - - # if we're here, then we failed all the converters - raise BadUnionArgument(param, union_args, errors) - - if origin is Literal: - errors = [] - conversions = {} - literal_args = converter.__args__ - for literal in literal_args: - literal_type = type(literal) - try: - value = conversions[literal_type] - except KeyError: - try: - value = await self._actual_conversion(ctx, literal_type, argument, param) - except CommandError as exc: - errors.append(exc) - conversions[literal_type] = object() - continue - else: - conversions[literal_type] = value - - if value == literal: - return value - - # if we're here, then we failed to match all the literals - raise BadLiteralArgument(param, literal_args, errors) - - return await self._actual_conversion(ctx, converter, argument, param) - - def _get_converter(self, param): - converter = param.annotation - if converter is param.empty: - if param.default is not param.empty: - converter = str if param.default is None else type(param.default) - else: - converter = str - return converter - async def transform(self, ctx, param): required = param.default is param.empty - converter = self._get_converter(param) + converter = get_converter(param) consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw view = ctx.view view.skip_ws() # The greedy converter is simple -- it keeps going until it fails in which case, # it undos the view ready for the next parameter to use instead - if isinstance(converter, converters.Greedy): + if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): return await self._transform_greedy_pos(ctx, param, required, converter.converter) elif param.kind == param.VAR_POSITIONAL: @@ -674,7 +569,7 @@ class Command(_BaseCommand): argument = view.get_quoted_word() view.previous = previous - return await self.do_conversion(ctx, converter, argument, param) + return await run_converters(ctx, converter, argument, param) async def _transform_greedy_pos(self, ctx, param, required, converter): view = ctx.view @@ -686,7 +581,7 @@ class Command(_BaseCommand): view.skip_ws() try: argument = view.get_quoted_word() - value = await self.do_conversion(ctx, converter, argument, param) + value = await run_converters(ctx, converter, argument, param) except (CommandError, ArgumentParsingError): view.index = previous break @@ -702,7 +597,7 @@ class Command(_BaseCommand): previous = view.index try: argument = view.get_quoted_word() - value = await self.do_conversion(ctx, converter, argument, param) + value = await run_converters(ctx, converter, argument, param) except (CommandError, ArgumentParsingError): view.index = previous raise RuntimeError() from None # break loop @@ -826,9 +721,9 @@ class Command(_BaseCommand): elif param.kind == param.KEYWORD_ONLY: # kwarg only param denotes "consume rest" semantics if self.rest_is_raw: - converter = self._get_converter(param) + converter = get_converter(param) argument = view.read_rest() - kwargs[name] = await self.do_conversion(ctx, converter, argument, param) + kwargs[name] = await run_converters(ctx, converter, argument, param) else: kwargs[name] = await self.transform(ctx, param) break @@ -1126,7 +1021,7 @@ class Command(_BaseCommand): result = [] for name, param in params.items(): - greedy = isinstance(param.annotation, converters.Greedy) + greedy = isinstance(param.annotation, Greedy) optional = False # postpone evaluation of if it's an optional argument # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index eb9300390..b8ef64cf3 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -329,6 +329,8 @@ Converters .. autoclass:: discord.ext.commands.Greedy() +.. autofunction:: discord.ext.commands.run_converters + .. _ext_commands_api_errors: Exceptions