Browse Source

[commands]Add typing.Literal converter

pull/6692/head
Sigmath Bits 4 years ago
committed by GitHub
parent
commit
68aef92b37
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 95
      discord/ext/commands/core.py
  2. 34
      discord/ext/commands/errors.py
  3. 21
      docs/ext/commands/commands.rst

95
discord/ext/commands/core.py

@ -489,31 +489,52 @@ class Command(_BaseCommand):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def do_conversion(self, ctx, converter, argument, param):
try:
origin = converter.__origin__
except AttributeError:
pass
else:
if origin is typing.Union:
errors = []
_NoneType = type(None)
for conv in converter.__args__:
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
ctx.view.undo()
return None if param.default is param.empty else param.default
origin = typing.get_origin(converter)
if origin is typing.Union:
errors = []
_NoneType = type(None)
for conv in typing.get_args(converter):
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
ctx.view.undo()
return None if param.default is param.empty else param.default
try:
value = await self.do_conversion(ctx, conv, argument, param)
except CommandError as exc:
errors.append(exc)
else:
return value
# if we're here, then we failed all the converters
raise BadUnionArgument(param, typing.get_args(converter), errors)
if origin is typing.Literal:
errors = []
conversions = {}
literal_args = tuple(self._flattened_typing_literal_args(converter))
for literal in literal_args:
literal_type = type(literal)
try:
value = conversions[literal_type]
except KeyError:
try:
value = await self._actual_conversion(ctx, conv, argument, param)
value = await self._actual_conversion(ctx, literal_type, argument, param)
except CommandError as exc:
errors.append(exc)
conversions[literal_type] = object()
continue
else:
return value
conversions[literal_type] = value
if value == literal:
return value
# if we're here, then we failed all the converters
raise BadUnionArgument(param, converter.__args__, errors)
# if we're here, then we failed to match all the literals
raise BadLiteralArgument(param, literal_args, errors)
return await self._actual_conversion(ctx, converter, argument, param)
@ -995,15 +1016,14 @@ class Command(_BaseCommand):
return ''
def _is_typing_optional(self, annotation):
try:
origin = annotation.__origin__
except AttributeError:
return False
return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)
if origin is not typing.Union:
return False
return annotation.__args__[-1] is type(None)
def _flattened_typing_literal_args(self, annotation):
for literal in typing.get_args(annotation):
if typing.get_origin(literal) is typing.Literal:
yield from self._flattened_typing_literal_args(literal)
else:
yield literal
@property
def signature(self):
@ -1011,7 +1031,6 @@ class Command(_BaseCommand):
if self.usage is not None:
return self.usage
params = self.clean_params
if not params:
return ''
@ -1019,6 +1038,22 @@ class Command(_BaseCommand):
result = []
for name, param in params.items():
greedy = isinstance(param.annotation, converters._Greedy)
optional = False # postpone evaluation of if it's an optional argument
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
origin = typing.get_origin(annotation)
if not greedy and origin is typing.Union:
union_args = typing.get_args(annotation)
optional = union_args[-1] is type(None)
if optional:
annotation = union_args[0]
origin = typing.get_origin(annotation)
if origin is typing.Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
for v in self._flattened_typing_literal_args(annotation))
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
@ -1038,7 +1073,7 @@ class Command(_BaseCommand):
result.append(f'[{name}...]')
elif greedy:
result.append(f'[{name}]...')
elif self._is_typing_optional(param.annotation):
elif optional:
result.append(f'[{name}]')
else:
result.append(f'<{name}>')

34
discord/ext/commands/errors.py

@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from discord.errors import ClientException, DiscordException
import typing
__all__ = (
@ -62,6 +63,7 @@ __all__ = (
'NSFWChannelRequired',
'ConversionError',
'BadUnionArgument',
'BadLiteralArgument',
'ArgumentParsingError',
'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError',
@ -644,6 +646,8 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
if typing.get_origin(x) is not None:
return repr(x)
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
@ -654,6 +658,36 @@ class BadUnionArgument(UserInputError):
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all
its associated values.
This inherits from :exc:`UserInputError`
.. versionadded:: 2.0
Attributes
-----------
param: :class:`inspect.Parameter`
The parameter that failed being converted.
literals: Tuple[Any, ...]
A tuple of values compared against in conversion, in order of failure.
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param, literals, errors):
self.param = param
self.literals = literals
self.errors = errors
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input.

21
docs/ext/commands/commands.rst

@ -502,6 +502,27 @@ resumes handling, which in this case would be to pass it into the ``liquid`` par
This converter only works in regular positional parameters, not variable parameters or keyword-only parameters.
typing.Literal
^^^^^^^^^^^^^^^^
A :data:`typing.Literal` is a special type hint that requires the passed parameter to be equal to one of the listed values
after being converted to the same type. For example, given the following:
.. code-block:: python3
from typing import Literal
@bot.command()
async def shop(ctx, buy_sell: Literal['buy', 'sell'], amount: Literal[1, 2], *, item: str):
await ctx.send(f'{buy_sell.capitalize()}ing {amount} {item}(s)!')
The ``buy_sell`` parameter must be either the literal string ``"buy"`` or ``"sell"`` and ``amount`` must convert to the
``int`` ``1`` or ``2``. If ``buy_sell`` or ``amount`` don't match any value, then a special error is raised,
:exc:`~.ext.commands.BadLiteralArgument`. Any literal values can be mixed and matched within the same :data:`typing.Literal` converter.
Note that ``typing.Literal[True]`` and ``typing.Literal[False]`` still follow the :class:`bool` converter rules.
Greedy
^^^^^^^^

Loading…
Cancel
Save