Browse Source

[commands] use __args__ and __origin__ where applicable

pull/6701/head
Josh 4 years ago
committed by GitHub
parent
commit
7f91ae8b67
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 41
      discord/ext/commands/core.py
  2. 3
      discord/ext/commands/errors.py

41
discord/ext/commands/core.py

@ -30,8 +30,6 @@ from typing import (
Literal,
Tuple,
Union,
get_args as get_typing_args,
get_origin as get_typing_origin,
)
import asyncio
import functools
@ -86,6 +84,10 @@ def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params.append(p)
return tuple(params)
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
none_cls = type(None)
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
@ -102,6 +104,12 @@ def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any]
if hasattr(tp, '__args__'):
implicit_str = True
args = tp.__args__
if tp.__origin__ is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
@ -547,12 +555,13 @@ class Command(_BaseCommand):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def do_conversion(self, ctx, converter, argument, param):
origin = get_typing_origin(converter)
origin = getattr(converter, '__origin__', None)
if origin is Union:
errors = []
_NoneType = type(None)
for conv in get_typing_args(converter):
union_args = converter.__args__
for conv in union_args:
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
@ -568,12 +577,13 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed all the converters
raise BadUnionArgument(param, get_typing_args(converter), errors)
raise BadUnionArgument(param, union_args, errors)
if origin is Literal:
errors = []
conversions = {}
for literal in converter.__args__:
literal_args = converter.__args__
for literal in literal_args:
literal_type = type(literal)
try:
value = conversions[literal_type]
@ -591,7 +601,7 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed to match all the literals
raise BadLiteralArgument(param, converter.__args__, errors)
raise BadLiteralArgument(param, literal_args, errors)
return await self._actual_conversion(ctx, converter, argument, param)
@ -614,7 +624,7 @@ class Command(_BaseCommand):
# The greedy converter is simple -- it keeps going until it fails in which case,
# it undos the view ready for the next parameter to use instead
if isinstance(converter, converters.Greedy):
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
elif param.kind == param.VAR_POSITIONAL:
return await self._transform_greedy_var_pos(ctx, param, converter.converter)
@ -782,7 +792,7 @@ class Command(_BaseCommand):
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
for name, param in iterator:
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
transformed = await self.transform(ctx, param)
args.append(transformed)
elif param.kind == param.KEYWORD_ONLY:
@ -1074,7 +1084,7 @@ class Command(_BaseCommand):
return ''
def _is_typing_optional(self, annotation):
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__
@property
def signature(self):
@ -1094,13 +1104,14 @@ class Command(_BaseCommand):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
origin = get_typing_origin(annotation)
origin = getattr(annotation, '__origin__', None)
if not greedy and origin is Union:
union_args = get_typing_args(annotation)
optional = union_args[-1] is type(None)
if optional:
none_cls = type(None)
union_args = annotation.__args__
optional = union_args[-1] is none_cls
if len(union_args) == 2 and optional:
annotation = union_args[0]
origin = get_typing_origin(annotation)
origin = getattr(annotation, '__origin__', None)
if origin is Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)

3
discord/ext/commands/errors.py

@ -23,7 +23,6 @@ DEALINGS IN THE SOFTWARE.
"""
from discord.errors import ClientException, DiscordException
import typing
__all__ = (
@ -646,7 +645,7 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
if typing.get_origin(x) is not None:
if hasattr(x, '__origin__'):
return repr(x)
return x.__class__.__name__

Loading…
Cancel
Save