Browse Source

Split annotation resolution to discord.utils

pull/6800/head
Rapptz 4 years ago
parent
commit
9f3551926a
  1. 99
      discord/ext/commands/core.py
  2. 2
      discord/ext/commands/flags.py
  3. 105
      discord/utils.py

99
discord/ext/commands/core.py

@ -25,11 +25,7 @@ DEALINGS IN THE SOFTWARE.
from typing import (
Any,
Dict,
ForwardRef,
Iterable,
Literal,
Optional,
Tuple,
Union,
)
import asyncio
@ -37,7 +33,6 @@ import functools
import inspect
import datetime
import types
import sys
import discord
@ -74,102 +69,12 @@ __all__ = (
'bot_has_guild_permissions'
)
PY_310 = sys.version_info >= (3, 10)
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
none_cls = type(None)
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
def _evaluate_annotation(
tp: Any,
globals: Dict[str, Any],
locals: Dict[str, Any],
cache: Dict[str, Any],
*,
implicit_str: bool = True,
):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True
if implicit_str and isinstance(tp, str):
if tp in cache:
return cache[tp]
evaluated = eval(tp, globals, locals)
cache[tp] = evaluated
return _evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if PY_310 and tp.__class__ is types.Union:
converted = Union[args] # type: ignore
return _evaluate_annotation(converted, globals, locals, cache)
return tp
if tp.__origin__ is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
is_literal = True
evaluated_args = tuple(
_evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
if evaluated_args == args:
return tp
try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return tp
def resolve_annotation(
annotation: Any,
globalns: Dict[str, Any],
localns: Optional[Dict[str, Any]],
cache: Optional[Dict[str, Any]],
) -> Any:
if annotation is None:
return type(None)
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
locals = globalns if localns is None else localns
if cache is None:
cache = {}
return _evaluate_annotation(annotation, globalns, locals, cache)
def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]:
globalns = function.__globals__
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
eval_annotation = discord.utils.evaluate_annotation
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
@ -179,7 +84,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.
params[name] = parameter.replace(annotation=type(None))
continue
annotation = _evaluate_annotation(annotation, globalns, globalns, cache)
annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')

2
discord/ext/commands/flags.py

@ -32,7 +32,7 @@ from .errors import (
MissingRequiredFlag,
)
from .core import resolve_annotation
from discord.utils import resolve_annotation
from .view import StringView
from .converter import run_converters

105
discord/utils.py

@ -31,13 +31,16 @@ from typing import (
AsyncIterator,
Callable,
Dict,
ForwardRef,
Generic,
Iterable,
Iterator,
List,
Literal,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
@ -53,6 +56,8 @@ from inspect import isawaitable as _isawaitable, signature as _signature
from operator import attrgetter
import json
import re
import sys
import types
import warnings
from .errors import InvalidArgument
@ -99,6 +104,7 @@ if TYPE_CHECKING:
class _RequestLike(Protocol):
headers: Dict[str, Any]
else:
cached_property = _cached_property
@ -741,6 +747,7 @@ def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
if ret:
yield ret
async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
ret = []
n = 0
@ -767,9 +774,9 @@ def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T
def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
"""A helper function that collects an iterator into chunks of a given size.
.. versionadded:: 2.0
Parameters
----------
iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`]
@ -793,3 +800,97 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size)
return _chunk(iterator, max_size)
PY_310 = sys.version_info >= (3, 10)
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
none_cls = type(None)
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
def evaluate_annotation(
tp: Any,
globals: Dict[str, Any],
locals: Dict[str, Any],
cache: Dict[str, Any],
*,
implicit_str: bool = True,
):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True
if implicit_str and isinstance(tp, str):
if tp in cache:
return cache[tp]
evaluated = eval(tp, globals, locals)
cache[tp] = evaluated
return evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if PY_310 and tp.__class__ is types.Union:
converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)
return tp
if tp.__origin__ is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
is_literal = True
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
if evaluated_args == args:
return tp
try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return tp
def resolve_annotation(
annotation: Any,
globalns: Dict[str, Any],
localns: Optional[Dict[str, Any]],
cache: Optional[Dict[str, Any]],
) -> Any:
if annotation is None:
return type(None)
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
locals = globalns if localns is None else localns
if cache is None:
cache = {}
return evaluate_annotation(annotation, globalns, locals, cache)

Loading…
Cancel
Save