|
|
@ -30,8 +30,6 @@ from typing import ( |
|
|
|
Literal, |
|
|
|
Tuple, |
|
|
|
Union, |
|
|
|
get_args as get_typing_args, |
|
|
|
get_origin as get_typing_origin, |
|
|
|
) |
|
|
|
import asyncio |
|
|
|
import functools |
|
|
@ -86,6 +84,10 @@ def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: |
|
|
|
params.append(p) |
|
|
|
return tuple(params) |
|
|
|
|
|
|
|
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: |
|
|
|
none_cls = type(None) |
|
|
|
return tuple(p for p in parameters if p is not none_cls) + (none_cls,) |
|
|
|
|
|
|
|
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True): |
|
|
|
if isinstance(tp, ForwardRef): |
|
|
|
tp = tp.__forward_arg__ |
|
|
@ -102,6 +104,12 @@ def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] |
|
|
|
if hasattr(tp, '__args__'): |
|
|
|
implicit_str = True |
|
|
|
args = tp.__args__ |
|
|
|
if tp.__origin__ is Union: |
|
|
|
try: |
|
|
|
if args.index(type(None)) != len(args) - 1: |
|
|
|
args = normalise_optional_params(tp.__args__) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
if tp.__origin__ is Literal: |
|
|
|
if not PY_310: |
|
|
|
args = flatten_literal_params(tp.__args__) |
|
|
@ -547,12 +555,13 @@ class Command(_BaseCommand): |
|
|
|
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc |
|
|
|
|
|
|
|
async def do_conversion(self, ctx, converter, argument, param): |
|
|
|
origin = get_typing_origin(converter) |
|
|
|
origin = getattr(converter, '__origin__', None) |
|
|
|
|
|
|
|
if origin is Union: |
|
|
|
errors = [] |
|
|
|
_NoneType = type(None) |
|
|
|
for conv in get_typing_args(converter): |
|
|
|
union_args = converter.__args__ |
|
|
|
for conv in union_args: |
|
|
|
# if we got to this part in the code, then the previous conversions have failed |
|
|
|
# so we should just undo the view, return the default, and allow parsing to continue |
|
|
|
# with the other parameters |
|
|
@ -568,12 +577,13 @@ class Command(_BaseCommand): |
|
|
|
return value |
|
|
|
|
|
|
|
# if we're here, then we failed all the converters |
|
|
|
raise BadUnionArgument(param, get_typing_args(converter), errors) |
|
|
|
raise BadUnionArgument(param, union_args, errors) |
|
|
|
|
|
|
|
if origin is Literal: |
|
|
|
errors = [] |
|
|
|
conversions = {} |
|
|
|
for literal in converter.__args__: |
|
|
|
literal_args = converter.__args__ |
|
|
|
for literal in literal_args: |
|
|
|
literal_type = type(literal) |
|
|
|
try: |
|
|
|
value = conversions[literal_type] |
|
|
@ -591,7 +601,7 @@ class Command(_BaseCommand): |
|
|
|
return value |
|
|
|
|
|
|
|
# if we're here, then we failed to match all the literals |
|
|
|
raise BadLiteralArgument(param, converter.__args__, errors) |
|
|
|
raise BadLiteralArgument(param, literal_args, errors) |
|
|
|
|
|
|
|
return await self._actual_conversion(ctx, converter, argument, param) |
|
|
|
|
|
|
@ -614,7 +624,7 @@ class Command(_BaseCommand): |
|
|
|
# The greedy converter is simple -- it keeps going until it fails in which case, |
|
|
|
# it undos the view ready for the next parameter to use instead |
|
|
|
if isinstance(converter, converters.Greedy): |
|
|
|
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: |
|
|
|
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): |
|
|
|
return await self._transform_greedy_pos(ctx, param, required, converter.converter) |
|
|
|
elif param.kind == param.VAR_POSITIONAL: |
|
|
|
return await self._transform_greedy_var_pos(ctx, param, converter.converter) |
|
|
@ -782,7 +792,7 @@ class Command(_BaseCommand): |
|
|
|
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') |
|
|
|
|
|
|
|
for name, param in iterator: |
|
|
|
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY: |
|
|
|
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): |
|
|
|
transformed = await self.transform(ctx, param) |
|
|
|
args.append(transformed) |
|
|
|
elif param.kind == param.KEYWORD_ONLY: |
|
|
@ -1074,7 +1084,7 @@ class Command(_BaseCommand): |
|
|
|
return '' |
|
|
|
|
|
|
|
def _is_typing_optional(self, annotation): |
|
|
|
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None) |
|
|
|
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ |
|
|
|
|
|
|
|
@property |
|
|
|
def signature(self): |
|
|
@ -1094,13 +1104,14 @@ class Command(_BaseCommand): |
|
|
|
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the |
|
|
|
# parameter signature is a literal list of it's values |
|
|
|
annotation = param.annotation.converter if greedy else param.annotation |
|
|
|
origin = get_typing_origin(annotation) |
|
|
|
origin = getattr(annotation, '__origin__', None) |
|
|
|
if not greedy and origin is Union: |
|
|
|
union_args = get_typing_args(annotation) |
|
|
|
optional = union_args[-1] is type(None) |
|
|
|
if optional: |
|
|
|
none_cls = type(None) |
|
|
|
union_args = annotation.__args__ |
|
|
|
optional = union_args[-1] is none_cls |
|
|
|
if len(union_args) == 2 and optional: |
|
|
|
annotation = union_args[0] |
|
|
|
origin = get_typing_origin(annotation) |
|
|
|
origin = getattr(annotation, '__origin__', None) |
|
|
|
|
|
|
|
if origin is Literal: |
|
|
|
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) |
|
|
|