diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 9bd8f4f81..ab8c52c1b 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -22,10 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import ( + Any, + Dict, + ForwardRef, + Iterable, + Literal, + Tuple, + Union, + get_args as get_typing_args, + get_origin as get_typing_origin, +) import asyncio import functools import inspect -import typing import datetime import sys @@ -64,6 +74,83 @@ __all__ = ( 'bot_has_guild_permissions' ) +PY_310 = sys.version_info >= (3, 10) + +def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: + params = [] + literal_cls = type(Literal[0]) + for p in parameters: + if isinstance(p, literal_cls): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + +def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True): + if isinstance(tp, ForwardRef): + tp = tp.__forward_arg__ + # ForwardRefs always evaluate their internals + implicit_str = True + + if implicit_str and isinstance(tp, str): + if tp in cache: + return cache[tp] + evaluated = eval(tp, globals) + cache[tp] = evaluated + return _evaluate_annotation(evaluated, globals, cache) + + if hasattr(tp, '__args__'): + implicit_str = True + args = tp.__args__ + if tp.__origin__ is Literal: + if not PY_310: + args = flatten_literal_params(tp.__args__) + implicit_str = False + + evaluated_args = tuple( + _evaluate_annotation(arg, globals, cache, implicit_str=implicit_str) for arg in args + ) + + if evaluated_args == args: + return tp + + try: + return tp.copy_with(evaluated_args) + except AttributeError: + return tp.__origin__[evaluated_args] + + return tp + +def resolve_annotation(annotation: Any, globalns: Dict[str, Any], cache: Dict[str, Any] = {}) -> Any: + if annotation is None: + return type(None) + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + return _evaluate_annotation(annotation, globalns, cache) + +def get_signature_parameters(function) -> Dict[str, inspect.Parameter]: + globalns = function.__globals__ + signature = inspect.signature(function) + params = {} + cache: Dict[str, Any] = {} + for name, parameter in signature.parameters.items(): + annotation = parameter.annotation + if annotation is parameter.empty: + params[name] = parameter + continue + if annotation is None: + params[name] = parameter.replace(annotation=type(None)) + continue + + annotation = _evaluate_annotation(annotation, globalns, cache) + if annotation is converters.Greedy: + raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') + + params[name] = parameter.replace(annotation=annotation) + + return params + + def wrap_callback(coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -300,40 +387,7 @@ class Command(_BaseCommand): def callback(self, function): self._callback = function self.module = function.__module__ - - signature = inspect.signature(function) - self.params = signature.parameters.copy() - - # see: https://bugs.python.org/issue41341 - resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved - - try: - type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()} - except NameError as e: - raise NameError(f'unresolved forward reference: {e.args[0]}') from None - - for key, value in self.params.items(): - # coalesce the forward references - if key in type_hints: - self.params[key] = value = value.replace(annotation=type_hints[key]) - - # fail early for when someone passes an unparameterized Greedy type - if value.annotation is converters.Greedy: - raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') - - def _return_resolved(self, type, **kwargs): - return type - - def _recursive_resolve(self, type, *, globals=None): - if not isinstance(type, typing.ForwardRef): - return type - - resolved = eval(type.__forward_arg__, globals) - args = typing.get_args(resolved) - for index, arg in enumerate(args): - inner_resolve_result = self._recursive_resolve(arg, globals=globals) - resolved[index] = inner_resolve_result - return resolved + self.params = get_signature_parameters(function) def add_check(self, func): """Adds a check to the command. @@ -493,12 +547,12 @@ class Command(_BaseCommand): raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc async def do_conversion(self, ctx, converter, argument, param): - origin = typing.get_origin(converter) + origin = get_typing_origin(converter) - if origin is typing.Union: + if origin is Union: errors = [] _NoneType = type(None) - for conv in typing.get_args(converter): + for conv in get_typing_args(converter): # if we got to this part in the code, then the previous conversions have failed # so we should just undo the view, return the default, and allow parsing to continue # with the other parameters @@ -514,13 +568,12 @@ class Command(_BaseCommand): return value # if we're here, then we failed all the converters - raise BadUnionArgument(param, typing.get_args(converter), errors) + raise BadUnionArgument(param, get_typing_args(converter), errors) - if origin is typing.Literal: + if origin is Literal: errors = [] conversions = {} - literal_args = tuple(self._flattened_typing_literal_args(converter)) - for literal in literal_args: + for literal in converter.__args__: literal_type = type(literal) try: value = conversions[literal_type] @@ -538,7 +591,7 @@ class Command(_BaseCommand): return value # if we're here, then we failed to match all the literals - raise BadLiteralArgument(param, literal_args, errors) + raise BadLiteralArgument(param, converter.__args__, errors) return await self._actual_conversion(ctx, converter, argument, param) @@ -1021,14 +1074,7 @@ class Command(_BaseCommand): return '' def _is_typing_optional(self, annotation): - return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None) - - def _flattened_typing_literal_args(self, annotation): - for literal in typing.get_args(annotation): - if typing.get_origin(literal) is typing.Literal: - yield from self._flattened_typing_literal_args(literal) - else: - yield literal + return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None) @property def signature(self): @@ -1048,17 +1094,16 @@ class Command(_BaseCommand): # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the # parameter signature is a literal list of it's values annotation = param.annotation.converter if greedy else param.annotation - origin = typing.get_origin(annotation) - if not greedy and origin is typing.Union: - union_args = typing.get_args(annotation) + origin = get_typing_origin(annotation) + if not greedy and origin is Union: + union_args = get_typing_args(annotation) optional = union_args[-1] is type(None) if optional: annotation = union_args[0] - origin = typing.get_origin(annotation) + origin = get_typing_origin(annotation) - if origin is typing.Literal: - name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) - for v in self._flattened_typing_literal_args(annotation)) + if origin is Literal: + name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user.