From d1039e209e4558f7b600707dc4442a84f86c287c Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 5 Jun 2022 01:18:16 -0400 Subject: [PATCH] Skip default parameter validation when using a transformer Fixes #8110 --- discord/app_commands/transformers.py | 40 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py index 3e7d312e9..79b9c0faf 100644 --- a/discord/app_commands/transformers.py +++ b/discord/app_commands/transformers.py @@ -613,32 +613,34 @@ def get_supported_annotation( *, _none: type = NoneType, _mapping: Dict[Any, Type[Transformer]] = BUILT_IN_TRANSFORMERS, -) -> Tuple[Any, Any]: +) -> Tuple[Any, Any, bool]: """Returns an appropriate, yet supported, annotation along with an optional default value. + The third boolean element of the tuple indicates if default values should be validated. + This differs from the built in mapping by supporting a few more things. Likewise, this returns a "transformed" annotation that is ready to use with CommandParameter.transform. """ try: - return (_mapping[annotation], MISSING) + return (_mapping[annotation], MISSING, True) except KeyError: pass if hasattr(annotation, '__discord_app_commands_transform__'): - return (annotation.metadata, MISSING) + return (annotation.metadata, MISSING, False) if hasattr(annotation, '__metadata__'): return get_supported_annotation(annotation.__metadata__[0]) if inspect.isclass(annotation): if issubclass(annotation, Transformer): - return (annotation, MISSING) + return (annotation, MISSING, False) if issubclass(annotation, (Enum, InternalEnum)): if all(isinstance(v.value, (str, int, float)) for v in annotation): - return (_make_enum_transformer(annotation), MISSING) + return (_make_enum_transformer(annotation), MISSING, False) else: - return (_make_complex_enum_transformer(annotation), MISSING) + return (_make_complex_enum_transformer(annotation), MISSING, False) if annotation is Choice: raise TypeError(f'Choice requires a type argument of int, str, or float') @@ -646,11 +648,11 @@ def get_supported_annotation( origin = getattr(annotation, '__origin__', None) if origin is Literal: args = annotation.__args__ # type: ignore - return (_make_literal_transformer(args), MISSING) + return (_make_literal_transformer(args), MISSING, True) if origin is Choice: arg = annotation.__args__[0] # type: ignore - return (_make_choice_transformer(arg), MISSING) + return (_make_choice_transformer(arg), MISSING, True) if origin is not Union: # Only Union/Optional is supported right now so bail early @@ -661,10 +663,10 @@ def get_supported_annotation( if args[-1] is _none: if len(args) == 2: underlying = args[0] - inner, _ = get_supported_annotation(underlying) + inner, _, validate_default = get_supported_annotation(underlying) if inner is None: raise TypeError(f'unsupported inner optional type {underlying!r}') - return (inner, None) + return (inner, None, validate_default) else: args = args[:-1] default = None @@ -672,7 +674,7 @@ def get_supported_annotation( # Check for channel union types if any(arg in CHANNEL_TO_TYPES for arg in args): # If any channel type is given, then *all* must be channel types - return (channel_transformer(*args, raw=None), default) + return (channel_transformer(*args, raw=None), default, True) # The only valid transformations here are: # [Member, User] => user @@ -682,9 +684,9 @@ def get_supported_annotation( if not all(arg in supported_types for arg in args): raise TypeError(f'unsupported types given inside {annotation!r}') if args == (User, Member) or args == (Member, User): - return (passthrough_transformer(AppCommandOptionType.user), default) + return (passthrough_transformer(AppCommandOptionType.user), default, True) - return (passthrough_transformer(AppCommandOptionType.mentionable), default) + return (passthrough_transformer(AppCommandOptionType.mentionable), default, True) def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter: @@ -695,7 +697,7 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co of a command parameter. """ - (inner, default) = get_supported_annotation(annotation) + (inner, default, validate_default) = get_supported_annotation(annotation) type = inner.type() if default is MISSING or default is None: @@ -704,12 +706,10 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co default = param_default # Verify validity of the default parameter - if default is not MISSING: - enum_type = getattr(inner, '__discord_app_commands_transformer_enum__', None) - if default.__class__ is not enum_type: - valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) - if not isinstance(default, valid_types): - raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') + if default is not MISSING and validate_default: + valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) + if not isinstance(default, valid_types): + raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') result = CommandParameter( type=type,