Browse Source

[commands] use __args__ and __origin__ where applicable

pull/6701/head
Josh 4 years ago
committed by GitHub
parent
commit
7f91ae8b67
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 41
      discord/ext/commands/core.py
  2. 3
      discord/ext/commands/errors.py

41
discord/ext/commands/core.py

@ -30,8 +30,6 @@ from typing import (
Literal, Literal,
Tuple, Tuple,
Union, Union,
get_args as get_typing_args,
get_origin as get_typing_origin,
) )
import asyncio import asyncio
import functools import functools
@ -86,6 +84,10 @@ def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params.append(p) params.append(p)
return tuple(params) return tuple(params)
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
none_cls = type(None)
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True): def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
if isinstance(tp, ForwardRef): if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__ tp = tp.__forward_arg__
@ -102,6 +104,12 @@ def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any]
if hasattr(tp, '__args__'): if hasattr(tp, '__args__'):
implicit_str = True implicit_str = True
args = tp.__args__ args = tp.__args__
if tp.__origin__ is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal: if tp.__origin__ is Literal:
if not PY_310: if not PY_310:
args = flatten_literal_params(tp.__args__) args = flatten_literal_params(tp.__args__)
@ -547,12 +555,13 @@ class Command(_BaseCommand):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def do_conversion(self, ctx, converter, argument, param): async def do_conversion(self, ctx, converter, argument, param):
origin = get_typing_origin(converter) origin = getattr(converter, '__origin__', None)
if origin is Union: if origin is Union:
errors = [] errors = []
_NoneType = type(None) _NoneType = type(None)
for conv in get_typing_args(converter): union_args = converter.__args__
for conv in union_args:
# if we got to this part in the code, then the previous conversions have failed # if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue # so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters # with the other parameters
@ -568,12 +577,13 @@ class Command(_BaseCommand):
return value return value
# if we're here, then we failed all the converters # if we're here, then we failed all the converters
raise BadUnionArgument(param, get_typing_args(converter), errors) raise BadUnionArgument(param, union_args, errors)
if origin is Literal: if origin is Literal:
errors = [] errors = []
conversions = {} conversions = {}
for literal in converter.__args__: literal_args = converter.__args__
for literal in literal_args:
literal_type = type(literal) literal_type = type(literal)
try: try:
value = conversions[literal_type] value = conversions[literal_type]
@ -591,7 +601,7 @@ class Command(_BaseCommand):
return value return value
# if we're here, then we failed to match all the literals # if we're here, then we failed to match all the literals
raise BadLiteralArgument(param, converter.__args__, errors) raise BadLiteralArgument(param, literal_args, errors)
return await self._actual_conversion(ctx, converter, argument, param) return await self._actual_conversion(ctx, converter, argument, param)
@ -614,7 +624,7 @@ class Command(_BaseCommand):
# The greedy converter is simple -- it keeps going until it fails in which case, # The greedy converter is simple -- it keeps going until it fails in which case,
# it undos the view ready for the next parameter to use instead # it undos the view ready for the next parameter to use instead
if isinstance(converter, converters.Greedy): if isinstance(converter, converters.Greedy):
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
return await self._transform_greedy_pos(ctx, param, required, converter.converter) return await self._transform_greedy_pos(ctx, param, required, converter.converter)
elif param.kind == param.VAR_POSITIONAL: elif param.kind == param.VAR_POSITIONAL:
return await self._transform_greedy_var_pos(ctx, param, converter.converter) return await self._transform_greedy_var_pos(ctx, param, converter.converter)
@ -782,7 +792,7 @@ class Command(_BaseCommand):
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
for name, param in iterator: for name, param in iterator:
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
transformed = await self.transform(ctx, param) transformed = await self.transform(ctx, param)
args.append(transformed) args.append(transformed)
elif param.kind == param.KEYWORD_ONLY: elif param.kind == param.KEYWORD_ONLY:
@ -1074,7 +1084,7 @@ class Command(_BaseCommand):
return '' return ''
def _is_typing_optional(self, annotation): def _is_typing_optional(self, annotation):
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None) return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__
@property @property
def signature(self): def signature(self):
@ -1094,13 +1104,14 @@ class Command(_BaseCommand):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values # parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation annotation = param.annotation.converter if greedy else param.annotation
origin = get_typing_origin(annotation) origin = getattr(annotation, '__origin__', None)
if not greedy and origin is Union: if not greedy and origin is Union:
union_args = get_typing_args(annotation) none_cls = type(None)
optional = union_args[-1] is type(None) union_args = annotation.__args__
if optional: optional = union_args[-1] is none_cls
if len(union_args) == 2 and optional:
annotation = union_args[0] annotation = union_args[0]
origin = get_typing_origin(annotation) origin = getattr(annotation, '__origin__', None)
if origin is Literal: if origin is Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)

3
discord/ext/commands/errors.py

@ -23,7 +23,6 @@ DEALINGS IN THE SOFTWARE.
""" """
from discord.errors import ClientException, DiscordException from discord.errors import ClientException, DiscordException
import typing
__all__ = ( __all__ = (
@ -646,7 +645,7 @@ class BadUnionArgument(UserInputError):
try: try:
return x.__name__ return x.__name__
except AttributeError: except AttributeError:
if typing.get_origin(x) is not None: if hasattr(x, '__origin__'):
return repr(x) return repr(x)
return x.__class__.__name__ return x.__class__.__name__

Loading…
Cancel
Save