@ -22,10 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE .
"""
from typing import (
Any ,
Dict ,
ForwardRef ,
Iterable ,
Literal ,
Tuple ,
Union ,
get_args as get_typing_args ,
get_origin as get_typing_origin ,
)
import asyncio
import functools
import inspect
import typing
import datetime
import sys
@ -64,6 +74,83 @@ __all__ = (
' bot_has_guild_permissions '
)
PY_310 = sys . version_info > = ( 3 , 10 )
def flatten_literal_params ( parameters : Iterable [ Any ] ) - > Tuple [ Any , . . . ] :
params = [ ]
literal_cls = type ( Literal [ 0 ] )
for p in parameters :
if isinstance ( p , literal_cls ) :
params . extend ( p . __args__ )
else :
params . append ( p )
return tuple ( params )
def _evaluate_annotation ( tp : Any , globals : Dict [ str , Any ] , cache : Dict [ str , Any ] = { } , * , implicit_str = True ) :
if isinstance ( tp , ForwardRef ) :
tp = tp . __forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True
if implicit_str and isinstance ( tp , str ) :
if tp in cache :
return cache [ tp ]
evaluated = eval ( tp , globals )
cache [ tp ] = evaluated
return _evaluate_annotation ( evaluated , globals , cache )
if hasattr ( tp , ' __args__ ' ) :
implicit_str = True
args = tp . __args__
if tp . __origin__ is Literal :
if not PY_310 :
args = flatten_literal_params ( tp . __args__ )
implicit_str = False
evaluated_args = tuple (
_evaluate_annotation ( arg , globals , cache , implicit_str = implicit_str ) for arg in args
)
if evaluated_args == args :
return tp
try :
return tp . copy_with ( evaluated_args )
except AttributeError :
return tp . __origin__ [ evaluated_args ]
return tp
def resolve_annotation ( annotation : Any , globalns : Dict [ str , Any ] , cache : Dict [ str , Any ] = { } ) - > Any :
if annotation is None :
return type ( None )
if isinstance ( annotation , str ) :
annotation = ForwardRef ( annotation )
return _evaluate_annotation ( annotation , globalns , cache )
def get_signature_parameters ( function ) - > Dict [ str , inspect . Parameter ] :
globalns = function . __globals__
signature = inspect . signature ( function )
params = { }
cache : Dict [ str , Any ] = { }
for name , parameter in signature . parameters . items ( ) :
annotation = parameter . annotation
if annotation is parameter . empty :
params [ name ] = parameter
continue
if annotation is None :
params [ name ] = parameter . replace ( annotation = type ( None ) )
continue
annotation = _evaluate_annotation ( annotation , globalns , cache )
if annotation is converters . Greedy :
raise TypeError ( ' Unparameterized Greedy[...] is disallowed in signature. ' )
params [ name ] = parameter . replace ( annotation = annotation )
return params
def wrap_callback ( coro ) :
@functools . wraps ( coro )
async def wrapped ( * args , * * kwargs ) :
@ -300,40 +387,7 @@ class Command(_BaseCommand):
def callback ( self , function ) :
self . _callback = function
self . module = function . __module__
signature = inspect . signature ( function )
self . params = signature . parameters . copy ( )
# see: https://bugs.python.org/issue41341
resolve = self . _recursive_resolve if sys . version_info < ( 3 , 9 ) else self . _return_resolved
try :
type_hints = { k : resolve ( v ) for k , v in typing . get_type_hints ( function ) . items ( ) }
except NameError as e :
raise NameError ( f ' unresolved forward reference: { e . args [ 0 ] } ' ) from None
for key , value in self . params . items ( ) :
# coalesce the forward references
if key in type_hints :
self . params [ key ] = value = value . replace ( annotation = type_hints [ key ] )
# fail early for when someone passes an unparameterized Greedy type
if value . annotation is converters . Greedy :
raise TypeError ( ' Unparameterized Greedy[...] is disallowed in signature. ' )
def _return_resolved ( self , type , * * kwargs ) :
return type
def _recursive_resolve ( self , type , * , globals = None ) :
if not isinstance ( type , typing . ForwardRef ) :
return type
resolved = eval ( type . __forward_arg__ , globals )
args = typing . get_args ( resolved )
for index , arg in enumerate ( args ) :
inner_resolve_result = self . _recursive_resolve ( arg , globals = globals )
resolved [ index ] = inner_resolve_result
return resolved
self . params = get_signature_parameters ( function )
def add_check ( self , func ) :
""" Adds a check to the command.
@ -493,12 +547,12 @@ class Command(_BaseCommand):
raise BadArgument ( f ' Converting to " { name } " failed for parameter " { param . name } " . ' ) from exc
async def do_conversion ( self , ctx , converter , argument , param ) :
origin = typing . get _origin( converter )
origin = get_ typing_origin( converter )
if origin is typing . Union :
if origin is Union :
errors = [ ]
_NoneType = type ( None )
for conv in typing . get _args( converter ) :
for conv in get_ typing_args( converter ) :
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
@ -514,13 +568,12 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed all the converters
raise BadUnionArgument ( param , typing . get _args( converter ) , errors )
raise BadUnionArgument ( param , get_ typing_args( converter ) , errors )
if origin is typing . Literal :
if origin is Literal :
errors = [ ]
conversions = { }
literal_args = tuple ( self . _flattened_typing_literal_args ( converter ) )
for literal in literal_args :
for literal in converter . __args__ :
literal_type = type ( literal )
try :
value = conversions [ literal_type ]
@ -538,7 +591,7 @@ class Command(_BaseCommand):
return value
# if we're here, then we failed to match all the literals
raise BadLiteralArgument ( param , literal_args , errors )
raise BadLiteralArgument ( param , converter . __args__ , errors )
return await self . _actual_conversion ( ctx , converter , argument , param )
@ -1021,14 +1074,7 @@ class Command(_BaseCommand):
return ' '
def _is_typing_optional ( self , annotation ) :
return typing . get_origin ( annotation ) is typing . Union and typing . get_args ( annotation ) [ - 1 ] is type ( None )
def _flattened_typing_literal_args ( self , annotation ) :
for literal in typing . get_args ( annotation ) :
if typing . get_origin ( literal ) is typing . Literal :
yield from self . _flattened_typing_literal_args ( literal )
else :
yield literal
return get_typing_origin ( annotation ) is Union and get_typing_args ( annotation ) [ - 1 ] is type ( None )
@property
def signature ( self ) :
@ -1048,17 +1094,16 @@ class Command(_BaseCommand):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param . annotation . converter if greedy else param . annotation
origin = typing . get _origin( annotation )
if not greedy and origin is typing . Union :
union_args = typing . get _args( annotation )
origin = get_ typing_origin( annotation )
if not greedy and origin is Union :
union_args = get_ typing_args( annotation )
optional = union_args [ - 1 ] is type ( None )
if optional :
annotation = union_args [ 0 ]
origin = typing . get _origin( annotation )
origin = get_ typing_origin( annotation )
if origin is typing . Literal :
name = ' | ' . join ( f ' " { v } " ' if isinstance ( v , str ) else str ( v )
for v in self . _flattened_typing_literal_args ( annotation ) )
if origin is Literal :
name = ' | ' . join ( f ' " { v } " ' if isinstance ( v , str ) else str ( v ) for v in annotation . __args__ )
if param . default is not param . empty :
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.