From 68aef92b377f61ed465660646659d4ba0100c314 Mon Sep 17 00:00:00 2001 From: Sigmath Bits <54879730+SigmathBits@users.noreply.github.com> Date: Sat, 10 Apr 2021 18:50:59 +1200 Subject: [PATCH] [commands]Add typing.Literal converter --- discord/ext/commands/core.py | 95 +++++++++++++++++++++++----------- discord/ext/commands/errors.py | 34 ++++++++++++ docs/ext/commands/commands.rst | 21 ++++++++ 3 files changed, 120 insertions(+), 30 deletions(-) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index a570ee48c..cce6f30cc 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -489,31 +489,52 @@ class Command(_BaseCommand): raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc async def do_conversion(self, ctx, converter, argument, param): - try: - origin = converter.__origin__ - except AttributeError: - pass - else: - if origin is typing.Union: - errors = [] - _NoneType = type(None) - for conv in converter.__args__: - # if we got to this part in the code, then the previous conversions have failed - # so we should just undo the view, return the default, and allow parsing to continue - # with the other parameters - if conv is _NoneType and param.kind != param.VAR_POSITIONAL: - ctx.view.undo() - return None if param.default is param.empty else param.default - + origin = typing.get_origin(converter) + + if origin is typing.Union: + errors = [] + _NoneType = type(None) + for conv in typing.get_args(converter): + # if we got to this part in the code, then the previous conversions have failed + # so we should just undo the view, return the default, and allow parsing to continue + # with the other parameters + if conv is _NoneType and param.kind != param.VAR_POSITIONAL: + ctx.view.undo() + return None if param.default is param.empty else param.default + + try: + value = await self.do_conversion(ctx, conv, argument, param) + except CommandError as exc: + errors.append(exc) + else: + return value + + # if we're here, then we failed all the converters + raise BadUnionArgument(param, typing.get_args(converter), errors) + + if origin is typing.Literal: + errors = [] + conversions = {} + literal_args = tuple(self._flattened_typing_literal_args(converter)) + for literal in literal_args: + literal_type = type(literal) + try: + value = conversions[literal_type] + except KeyError: try: - value = await self._actual_conversion(ctx, conv, argument, param) + value = await self._actual_conversion(ctx, literal_type, argument, param) except CommandError as exc: errors.append(exc) + conversions[literal_type] = object() + continue else: - return value + conversions[literal_type] = value + + if value == literal: + return value - # if we're here, then we failed all the converters - raise BadUnionArgument(param, converter.__args__, errors) + # if we're here, then we failed to match all the literals + raise BadLiteralArgument(param, literal_args, errors) return await self._actual_conversion(ctx, converter, argument, param) @@ -995,15 +1016,14 @@ class Command(_BaseCommand): return '' def _is_typing_optional(self, annotation): - try: - origin = annotation.__origin__ - except AttributeError: - return False + return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None) - if origin is not typing.Union: - return False - - return annotation.__args__[-1] is type(None) + def _flattened_typing_literal_args(self, annotation): + for literal in typing.get_args(annotation): + if typing.get_origin(literal) is typing.Literal: + yield from self._flattened_typing_literal_args(literal) + else: + yield literal @property def signature(self): @@ -1011,7 +1031,6 @@ class Command(_BaseCommand): if self.usage is not None: return self.usage - params = self.clean_params if not params: return '' @@ -1019,6 +1038,22 @@ class Command(_BaseCommand): result = [] for name, param in params.items(): greedy = isinstance(param.annotation, converters._Greedy) + optional = False # postpone evaluation of if it's an optional argument + + # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the + # parameter signature is a literal list of it's values + annotation = param.annotation.converter if greedy else param.annotation + origin = typing.get_origin(annotation) + if not greedy and origin is typing.Union: + union_args = typing.get_args(annotation) + optional = union_args[-1] is type(None) + if optional: + annotation = union_args[0] + origin = typing.get_origin(annotation) + + if origin is typing.Literal: + name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) + for v in self._flattened_typing_literal_args(annotation)) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should @@ -1038,7 +1073,7 @@ class Command(_BaseCommand): result.append(f'[{name}...]') elif greedy: result.append(f'[{name}]...') - elif self._is_typing_optional(param.annotation): + elif optional: result.append(f'[{name}]') else: result.append(f'<{name}>') diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index f8a2724d0..f7e745e0e 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from discord.errors import ClientException, DiscordException +import typing __all__ = ( @@ -62,6 +63,7 @@ __all__ = ( 'NSFWChannelRequired', 'ConversionError', 'BadUnionArgument', + 'BadLiteralArgument', 'ArgumentParsingError', 'UnexpectedQuoteError', 'InvalidEndOfQuotedStringError', @@ -644,6 +646,8 @@ class BadUnionArgument(UserInputError): try: return x.__name__ except AttributeError: + if typing.get_origin(x) is not None: + return repr(x) return x.__class__.__name__ to_string = [_get_name(x) for x in converters] @@ -654,6 +658,36 @@ class BadUnionArgument(UserInputError): super().__init__(f'Could not convert "{param.name}" into {fmt}.') +class BadLiteralArgument(UserInputError): + """Exception raised when a :data:`typing.Literal` converter fails for all + its associated values. + + This inherits from :exc:`UserInputError` + + .. versionadded:: 2.0 + + Attributes + ----------- + param: :class:`inspect.Parameter` + The parameter that failed being converted. + literals: Tuple[Any, ...] + A tuple of values compared against in conversion, in order of failure. + errors: List[:class:`CommandError`] + A list of errors that were caught from failing the conversion. + """ + def __init__(self, param, literals, errors): + self.param = param + self.literals = literals + self.errors = errors + + to_string = [repr(l) for l in literals] + if len(to_string) > 2: + fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) + else: + fmt = ' or '.join(to_string) + + super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.') + class ArgumentParsingError(UserInputError): """An exception raised when the parser fails to parse a user's input. diff --git a/docs/ext/commands/commands.rst b/docs/ext/commands/commands.rst index 3c5725369..837e00444 100644 --- a/docs/ext/commands/commands.rst +++ b/docs/ext/commands/commands.rst @@ -502,6 +502,27 @@ resumes handling, which in this case would be to pass it into the ``liquid`` par This converter only works in regular positional parameters, not variable parameters or keyword-only parameters. +typing.Literal +^^^^^^^^^^^^^^^^ + +A :data:`typing.Literal` is a special type hint that requires the passed parameter to be equal to one of the listed values +after being converted to the same type. For example, given the following: + +.. code-block:: python3 + + from typing import Literal + + @bot.command() + async def shop(ctx, buy_sell: Literal['buy', 'sell'], amount: Literal[1, 2], *, item: str): + await ctx.send(f'{buy_sell.capitalize()}ing {amount} {item}(s)!') + + +The ``buy_sell`` parameter must be either the literal string ``"buy"`` or ``"sell"`` and ``amount`` must convert to the +``int`` ``1`` or ``2``. If ``buy_sell`` or ``amount`` don't match any value, then a special error is raised, +:exc:`~.ext.commands.BadLiteralArgument`. Any literal values can be mixed and matched within the same :data:`typing.Literal` converter. + +Note that ``typing.Literal[True]`` and ``typing.Literal[False]`` still follow the :class:`bool` converter rules. + Greedy ^^^^^^^^