From 7f91ae8b676908fec8b7a76aef6f86993871fb05 Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 11 Apr 2021 14:38:17 +1000 Subject: [PATCH] [commands] use __args__ and __origin__ where applicable --- discord/ext/commands/core.py | 41 +++++++++++++++++++++------------- discord/ext/commands/errors.py | 3 +-- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index ab8c52c1b..215c89b1e 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -30,8 +30,6 @@ from typing import ( Literal, Tuple, Union, - get_args as get_typing_args, - get_origin as get_typing_origin, ) import asyncio import functools @@ -86,6 +84,10 @@ def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: params.append(p) return tuple(params) +def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: + none_cls = type(None) + return tuple(p for p in parameters if p is not none_cls) + (none_cls,) + def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True): if isinstance(tp, ForwardRef): tp = tp.__forward_arg__ @@ -102,6 +104,12 @@ def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] if hasattr(tp, '__args__'): implicit_str = True args = tp.__args__ + if tp.__origin__ is Union: + try: + if args.index(type(None)) != len(args) - 1: + args = normalise_optional_params(tp.__args__) + except ValueError: + pass if tp.__origin__ is Literal: if not PY_310: args = flatten_literal_params(tp.__args__) @@ -547,12 +555,13 @@ class Command(_BaseCommand): raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc async def do_conversion(self, ctx, converter, argument, param): - origin = get_typing_origin(converter) + origin = getattr(converter, '__origin__', None) if origin is Union: errors = [] _NoneType = type(None) - for conv in get_typing_args(converter): + union_args = converter.__args__ + for conv in union_args: # if we got to this part in the code, then the previous conversions have failed # so we should just undo the view, return the default, and allow parsing to continue # with the other parameters @@ -568,12 +577,13 @@ class Command(_BaseCommand): return value # if we're here, then we failed all the converters - raise BadUnionArgument(param, get_typing_args(converter), errors) + raise BadUnionArgument(param, union_args, errors) if origin is Literal: errors = [] conversions = {} - for literal in converter.__args__: + literal_args = converter.__args__ + for literal in literal_args: literal_type = type(literal) try: value = conversions[literal_type] @@ -591,7 +601,7 @@ class Command(_BaseCommand): return value # if we're here, then we failed to match all the literals - raise BadLiteralArgument(param, converter.__args__, errors) + raise BadLiteralArgument(param, literal_args, errors) return await self._actual_conversion(ctx, converter, argument, param) @@ -614,7 +624,7 @@ class Command(_BaseCommand): # The greedy converter is simple -- it keeps going until it fails in which case, # it undos the view ready for the next parameter to use instead if isinstance(converter, converters.Greedy): - if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: + if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): return await self._transform_greedy_pos(ctx, param, required, converter.converter) elif param.kind == param.VAR_POSITIONAL: return await self._transform_greedy_var_pos(ctx, param, converter.converter) @@ -782,7 +792,7 @@ class Command(_BaseCommand): raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') for name, param in iterator: - if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: + if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): transformed = await self.transform(ctx, param) args.append(transformed) elif param.kind == param.KEYWORD_ONLY: @@ -1074,7 +1084,7 @@ class Command(_BaseCommand): return '' def _is_typing_optional(self, annotation): - return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None) + return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ @property def signature(self): @@ -1094,13 +1104,14 @@ class Command(_BaseCommand): # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the # parameter signature is a literal list of it's values annotation = param.annotation.converter if greedy else param.annotation - origin = get_typing_origin(annotation) + origin = getattr(annotation, '__origin__', None) if not greedy and origin is Union: - union_args = get_typing_args(annotation) - optional = union_args[-1] is type(None) - if optional: + none_cls = type(None) + union_args = annotation.__args__ + optional = union_args[-1] is none_cls + if len(union_args) == 2 and optional: annotation = union_args[0] - origin = get_typing_origin(annotation) + origin = getattr(annotation, '__origin__', None) if origin is Literal: name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index b825057ec..98154d108 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -23,7 +23,6 @@ DEALINGS IN THE SOFTWARE. """ from discord.errors import ClientException, DiscordException -import typing __all__ = ( @@ -646,7 +645,7 @@ class BadUnionArgument(UserInputError): try: return x.__name__ except AttributeError: - if typing.get_origin(x) is not None: + if hasattr(x, '__origin__'): return repr(x) return x.__class__.__name__