Browse Source

Split annotation resolution to discord.utils

pull/6800/head
Rapptz 4 years ago
parent
commit
9f3551926a
  1. 99
      discord/ext/commands/core.py
  2. 2
      discord/ext/commands/flags.py
  3. 105
      discord/utils.py

99
discord/ext/commands/core.py

@ -25,11 +25,7 @@ DEALINGS IN THE SOFTWARE.
from typing import ( from typing import (
Any, Any,
Dict, Dict,
ForwardRef,
Iterable,
Literal, Literal,
Optional,
Tuple,
Union, Union,
) )
import asyncio import asyncio
@ -37,7 +33,6 @@ import functools
import inspect import inspect
import datetime import datetime
import types import types
import sys
import discord import discord
@ -74,102 +69,12 @@ __all__ = (
'bot_has_guild_permissions' 'bot_has_guild_permissions'
) )
PY_310 = sys.version_info >= (3, 10)
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
none_cls = type(None)
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
def _evaluate_annotation(
tp: Any,
globals: Dict[str, Any],
locals: Dict[str, Any],
cache: Dict[str, Any],
*,
implicit_str: bool = True,
):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True
if implicit_str and isinstance(tp, str):
if tp in cache:
return cache[tp]
evaluated = eval(tp, globals, locals)
cache[tp] = evaluated
return _evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if PY_310 and tp.__class__ is types.Union:
converted = Union[args] # type: ignore
return _evaluate_annotation(converted, globals, locals, cache)
return tp
if tp.__origin__ is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
is_literal = True
evaluated_args = tuple(
_evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
if evaluated_args == args:
return tp
try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return tp
def resolve_annotation(
annotation: Any,
globalns: Dict[str, Any],
localns: Optional[Dict[str, Any]],
cache: Optional[Dict[str, Any]],
) -> Any:
if annotation is None:
return type(None)
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
locals = globalns if localns is None else localns
if cache is None:
cache = {}
return _evaluate_annotation(annotation, globalns, locals, cache)
def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]: def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]:
globalns = function.__globals__ globalns = function.__globals__
signature = inspect.signature(function) signature = inspect.signature(function)
params = {} params = {}
cache: Dict[str, Any] = {} cache: Dict[str, Any] = {}
eval_annotation = discord.utils.evaluate_annotation
for name, parameter in signature.parameters.items(): for name, parameter in signature.parameters.items():
annotation = parameter.annotation annotation = parameter.annotation
if annotation is parameter.empty: if annotation is parameter.empty:
@ -179,7 +84,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.
params[name] = parameter.replace(annotation=type(None)) params[name] = parameter.replace(annotation=type(None))
continue continue
annotation = _evaluate_annotation(annotation, globalns, globalns, cache) annotation = eval_annotation(annotation, globalns, globalns, cache)
if annotation is Greedy: if annotation is Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')

2
discord/ext/commands/flags.py

@ -32,7 +32,7 @@ from .errors import (
MissingRequiredFlag, MissingRequiredFlag,
) )
from .core import resolve_annotation from discord.utils import resolve_annotation
from .view import StringView from .view import StringView
from .converter import run_converters from .converter import run_converters

105
discord/utils.py

@ -31,13 +31,16 @@ from typing import (
AsyncIterator, AsyncIterator,
Callable, Callable,
Dict, Dict,
ForwardRef,
Generic, Generic,
Iterable, Iterable,
Iterator, Iterator,
List, List,
Literal,
Optional, Optional,
Protocol, Protocol,
Sequence, Sequence,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -53,6 +56,8 @@ from inspect import isawaitable as _isawaitable, signature as _signature
from operator import attrgetter from operator import attrgetter
import json import json
import re import re
import sys
import types
import warnings import warnings
from .errors import InvalidArgument from .errors import InvalidArgument
@ -99,6 +104,7 @@ if TYPE_CHECKING:
class _RequestLike(Protocol): class _RequestLike(Protocol):
headers: Dict[str, Any] headers: Dict[str, Any]
else: else:
cached_property = _cached_property cached_property = _cached_property
@ -741,6 +747,7 @@ def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
if ret: if ret:
yield ret yield ret
async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]: async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
ret = [] ret = []
n = 0 n = 0
@ -767,9 +774,9 @@ def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T
def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]: def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
"""A helper function that collects an iterator into chunks of a given size. """A helper function that collects an iterator into chunks of a given size.
.. versionadded:: 2.0 .. versionadded:: 2.0
Parameters Parameters
---------- ----------
iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`] iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`]
@ -793,3 +800,97 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
if isinstance(iterator, AsyncIterator): if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size) return _achunk(iterator, max_size)
return _chunk(iterator, max_size) return _chunk(iterator, max_size)
PY_310 = sys.version_info >= (3, 10)
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
none_cls = type(None)
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
def evaluate_annotation(
tp: Any,
globals: Dict[str, Any],
locals: Dict[str, Any],
cache: Dict[str, Any],
*,
implicit_str: bool = True,
):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True
if implicit_str and isinstance(tp, str):
if tp in cache:
return cache[tp]
evaluated = eval(tp, globals, locals)
cache[tp] = evaluated
return evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if PY_310 and tp.__class__ is types.Union:
converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)
return tp
if tp.__origin__ is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
is_literal = True
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
if evaluated_args == args:
return tp
try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return tp
def resolve_annotation(
annotation: Any,
globalns: Dict[str, Any],
localns: Optional[Dict[str, Any]],
cache: Optional[Dict[str, Any]],
) -> Any:
if annotation is None:
return type(None)
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
locals = globalns if localns is None else localns
if cache is None:
cache = {}
return evaluate_annotation(annotation, globalns, locals, cache)

Loading…
Cancel
Save