diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index fcf58addd..cb986e3b2 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -25,11 +25,7 @@ DEALINGS IN THE SOFTWARE. from typing import ( Any, Dict, - ForwardRef, - Iterable, Literal, - Optional, - Tuple, Union, ) import asyncio @@ -37,7 +33,6 @@ import functools import inspect import datetime import types -import sys import discord @@ -74,102 +69,12 @@ __all__ = ( 'bot_has_guild_permissions' ) -PY_310 = sys.version_info >= (3, 10) - -def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: - params = [] - literal_cls = type(Literal[0]) - for p in parameters: - if isinstance(p, literal_cls): - params.extend(p.__args__) - else: - params.append(p) - return tuple(params) - -def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: - none_cls = type(None) - return tuple(p for p in parameters if p is not none_cls) + (none_cls,) - -def _evaluate_annotation( - tp: Any, - globals: Dict[str, Any], - locals: Dict[str, Any], - cache: Dict[str, Any], - *, - implicit_str: bool = True, -): - if isinstance(tp, ForwardRef): - tp = tp.__forward_arg__ - # ForwardRefs always evaluate their internals - implicit_str = True - - if implicit_str and isinstance(tp, str): - if tp in cache: - return cache[tp] - evaluated = eval(tp, globals, locals) - cache[tp] = evaluated - return _evaluate_annotation(evaluated, globals, locals, cache) - - if hasattr(tp, '__args__'): - implicit_str = True - is_literal = False - args = tp.__args__ - if not hasattr(tp, '__origin__'): - if PY_310 and tp.__class__ is types.Union: - converted = Union[args] # type: ignore - return _evaluate_annotation(converted, globals, locals, cache) - - return tp - if tp.__origin__ is Union: - try: - if args.index(type(None)) != len(args) - 1: - args = normalise_optional_params(tp.__args__) - except ValueError: - pass - if tp.__origin__ is Literal: - if not PY_310: - args = flatten_literal_params(tp.__args__) - implicit_str = False - is_literal = True - - evaluated_args = tuple( - _evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args - ) - - if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): - raise TypeError('Literal arguments must be of type str, int, bool, float or complex.') - - if evaluated_args == args: - return tp - - try: - return tp.copy_with(evaluated_args) - except AttributeError: - return tp.__origin__[evaluated_args] - - return tp - -def resolve_annotation( - annotation: Any, - globalns: Dict[str, Any], - localns: Optional[Dict[str, Any]], - cache: Optional[Dict[str, Any]], -) -> Any: - if annotation is None: - return type(None) - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - - locals = globalns if localns is None else localns - if cache is None: - cache = {} - return _evaluate_annotation(annotation, globalns, locals, cache) - def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]: globalns = function.__globals__ signature = inspect.signature(function) params = {} cache: Dict[str, Any] = {} + eval_annotation = discord.utils.evaluate_annotation for name, parameter in signature.parameters.items(): annotation = parameter.annotation if annotation is parameter.empty: @@ -179,7 +84,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect. params[name] = parameter.replace(annotation=type(None)) continue - annotation = _evaluate_annotation(annotation, globalns, globalns, cache) + annotation = eval_annotation(annotation, globalns, globalns, cache) if annotation is Greedy: raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index e58c9ce5b..3aa9a65ff 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -32,7 +32,7 @@ from .errors import ( MissingRequiredFlag, ) -from .core import resolve_annotation +from discord.utils import resolve_annotation from .view import StringView from .converter import run_converters diff --git a/discord/utils.py b/discord/utils.py index 88948da57..293103cd7 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -31,13 +31,16 @@ from typing import ( AsyncIterator, Callable, Dict, + ForwardRef, Generic, Iterable, Iterator, List, + Literal, Optional, Protocol, Sequence, + Tuple, Type, TypeVar, Union, @@ -53,6 +56,8 @@ from inspect import isawaitable as _isawaitable, signature as _signature from operator import attrgetter import json import re +import sys +import types import warnings from .errors import InvalidArgument @@ -99,6 +104,7 @@ if TYPE_CHECKING: class _RequestLike(Protocol): headers: Dict[str, Any] + else: cached_property = _cached_property @@ -741,6 +747,7 @@ def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]: if ret: yield ret + async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]: ret = [] n = 0 @@ -767,9 +774,9 @@ def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]: """A helper function that collects an iterator into chunks of a given size. - + .. versionadded:: 2.0 - + Parameters ---------- iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`] @@ -793,3 +800,97 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]: if isinstance(iterator, AsyncIterator): return _achunk(iterator, max_size) return _chunk(iterator, max_size) + + +PY_310 = sys.version_info >= (3, 10) + + +def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: + params = [] + literal_cls = type(Literal[0]) + for p in parameters: + if isinstance(p, literal_cls): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + + +def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: + none_cls = type(None) + return tuple(p for p in parameters if p is not none_cls) + (none_cls,) + + +def evaluate_annotation( + tp: Any, + globals: Dict[str, Any], + locals: Dict[str, Any], + cache: Dict[str, Any], + *, + implicit_str: bool = True, +): + if isinstance(tp, ForwardRef): + tp = tp.__forward_arg__ + # ForwardRefs always evaluate their internals + implicit_str = True + + if implicit_str and isinstance(tp, str): + if tp in cache: + return cache[tp] + evaluated = eval(tp, globals, locals) + cache[tp] = evaluated + return evaluate_annotation(evaluated, globals, locals, cache) + + if hasattr(tp, '__args__'): + implicit_str = True + is_literal = False + args = tp.__args__ + if not hasattr(tp, '__origin__'): + if PY_310 and tp.__class__ is types.Union: + converted = Union[args] # type: ignore + return evaluate_annotation(converted, globals, locals, cache) + + return tp + if tp.__origin__ is Union: + try: + if args.index(type(None)) != len(args) - 1: + args = normalise_optional_params(tp.__args__) + except ValueError: + pass + if tp.__origin__ is Literal: + if not PY_310: + args = flatten_literal_params(tp.__args__) + implicit_str = False + is_literal = True + + evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args) + + if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): + raise TypeError('Literal arguments must be of type str, int, bool, float or complex.') + + if evaluated_args == args: + return tp + + try: + return tp.copy_with(evaluated_args) + except AttributeError: + return tp.__origin__[evaluated_args] + + return tp + + +def resolve_annotation( + annotation: Any, + globalns: Dict[str, Any], + localns: Optional[Dict[str, Any]], + cache: Optional[Dict[str, Any]], +) -> Any: + if annotation is None: + return type(None) + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + + locals = globalns if localns is None else localns + if cache is None: + cache = {} + return evaluate_annotation(annotation, globalns, locals, cache)