Browse Source

[commands] Refactor typing evaluation to not use get_type_hints

get_type_hints had a few issues:

1. It would convert = None default parameters to Optional
2. It would not allow values as type annotations
3. It would not implicitly convert some string literals as ForwardRef

In Python 3.9 `list['Foo']` does not convert into
`list[ForwardRef('Foo')]` even though `typing.List` does this
behaviour. In order to streamline it, evaluation had to be rewritten
manually to support our usecases.

This patch also flattens nested typing.Literal which was not done
until Python 3.9.2.
pull/6698/head
Rapptz 4 years ago
parent
commit
3151672cfe
  1. 161
      discord/ext/commands/core.py

161
discord/ext/commands/core.py

@ -22,10 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import (
Any,
Dict,
ForwardRef,
Iterable,
Literal,
Tuple,
Union,
get_args as get_typing_args,
get_origin as get_typing_origin,
)
import asyncio
import functools
import inspect
import typing
import datetime
import sys
@ -64,6 +74,83 @@ __all__ = (
'bot_has_guild_permissions'
)
PY_310 = sys.version_info >= (3, 10)
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True
if implicit_str and isinstance(tp, str):
if tp in cache:
return cache[tp]
evaluated = eval(tp, globals)
cache[tp] = evaluated
return _evaluate_annotation(evaluated, globals, cache)
if hasattr(tp, '__args__'):
implicit_str = True
args = tp.__args__
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
evaluated_args = tuple(
_evaluate_annotation(arg, globals, cache, implicit_str=implicit_str) for arg in args
)
if evaluated_args == args:
return tp
try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return tp
def resolve_annotation(annotation: Any, globalns: Dict[str, Any], cache: Dict[str, Any] = {}) -> Any:
if annotation is None:
return type(None)
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
return _evaluate_annotation(annotation, globalns, cache)
def get_signature_parameters(function) -> Dict[str, inspect.Parameter]:
globalns = function.__globals__
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
params[name] = parameter
continue
if annotation is None:
params[name] = parameter.replace(annotation=type(None))
continue
annotation = _evaluate_annotation(annotation, globalns, cache)
if annotation is converters.Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
params[name] = parameter.replace(annotation=annotation)
return params
def wrap_callback(coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
@ -300,40 +387,7 @@ class Command(_BaseCommand):
def callback(self, function):
self._callback = function
self.module = function.__module__
signature = inspect.signature(function)
self.params = signature.parameters.copy()
# see: https://bugs.python.org/issue41341
resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved
try:
type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()}
except NameError as e:
raise NameError(f'unresolved forward reference: {e.args[0]}') from None
for key, value in self.params.items():
# coalesce the forward references
if key in type_hints:
self.params[key] = value = value.replace(annotation=type_hints[key])
# fail early for when someone passes an unparameterized Greedy type
if value.annotation is converters.Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
def _return_resolved(self, type, **kwargs):
return type
def _recursive_resolve(self, type, *, globals=None):
if not isinstance(type, typing.ForwardRef):
return type
resolved = eval(type.__forward_arg__, globals)
args = typing.get_args(resolved)
for index, arg in enumerate(args):
inner_resolve_result = self._recursive_resolve(arg, globals=globals)
resolved[index] = inner_resolve_result
return resolved
self.params = get_signature_parameters(function)
def add_check(self, func):
"""Adds a check to the command.
@ -493,12 +547,12 @@ class Command(_BaseCommand):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def do_conversion(self, ctx, converter, argument, param):
origin = typing.get_origin(converter)
origin = get_typing_origin(converter)
if origin is typing.Union:
if origin is Union:
errors = []
_NoneType = type(None)
for conv in typing.get_args(converter):
for conv in get_typing_args(converter):
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
@ -514,13 +568,12 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed all the converters
raise BadUnionArgument(param, typing.get_args(converter), errors)
raise BadUnionArgument(param, get_typing_args(converter), errors)
if origin is typing.Literal:
if origin is Literal:
errors = []
conversions = {}
literal_args = tuple(self._flattened_typing_literal_args(converter))
for literal in literal_args:
for literal in converter.__args__:
literal_type = type(literal)
try:
value = conversions[literal_type]
@ -538,7 +591,7 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed to match all the literals
raise BadLiteralArgument(param, literal_args, errors)
raise BadLiteralArgument(param, converter.__args__, errors)
return await self._actual_conversion(ctx, converter, argument, param)
@ -1021,14 +1074,7 @@ class Command(_BaseCommand):
return ''
def _is_typing_optional(self, annotation):
return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)
def _flattened_typing_literal_args(self, annotation):
for literal in typing.get_args(annotation):
if typing.get_origin(literal) is typing.Literal:
yield from self._flattened_typing_literal_args(literal)
else:
yield literal
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
@property
def signature(self):
@ -1048,17 +1094,16 @@ class Command(_BaseCommand):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
origin = typing.get_origin(annotation)
if not greedy and origin is typing.Union:
union_args = typing.get_args(annotation)
origin = get_typing_origin(annotation)
if not greedy and origin is Union:
union_args = get_typing_args(annotation)
optional = union_args[-1] is type(None)
if optional:
annotation = union_args[0]
origin = typing.get_origin(annotation)
origin = get_typing_origin(annotation)
if origin is typing.Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
for v in self._flattened_typing_literal_args(annotation))
if origin is Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.

Loading…
Cancel
Save