Browse Source

[commands] Properly support commands.param in hybrid commands

pull/7881/head
Rapptz 3 years ago
parent
commit
f072edfdfc
  1. 61
      discord/ext/commands/hybrid.py

61
discord/ext/commands/hybrid.py

@ -42,8 +42,9 @@ import inspect
from discord import app_commands
from discord.utils import MISSING, maybe_coroutine, async_all
from .core import Command, Group
from .errors import CommandRegistrationError, CommandError, HybridCommandError, ConversionError
from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError
from .converter import Converter
from .parameters import Parameter
from .cog import Cog
if TYPE_CHECKING:
@ -51,7 +52,6 @@ if TYPE_CHECKING:
from ._types import ContextT, Coro, BotT
from .bot import Bot
from .context import Context
from .parameters import Parameter
from discord.app_commands.commands import (
Check as AppCommandCheck,
AutocompleteCallback,
@ -71,6 +71,7 @@ CogT = TypeVar('CogT', bound='Cog')
CommandT = TypeVar('CommandT', bound='Command')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
_NoneType = type(None)
if TYPE_CHECKING:
P = ParamSpec('P')
@ -85,6 +86,17 @@ else:
P2 = TypeVar('P2')
class _CallableDefault:
__slots__ = ('func',)
def __init__(self, func: Callable[[Context], Any]) -> None:
self.func: Callable[[Context], Any] = func
@property
def __class__(self) -> Any:
return _NoneType
def is_converter(converter: Any) -> bool:
return (inspect.isclass(converter) and issubclass(converter, Converter)) or isinstance(converter, Converter)
@ -107,12 +119,33 @@ def make_converter_transformer(converter: Any) -> Type[app_commands.Transformer]
return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def make_callable_transformer(func: Callable[[str], Any]) -> Type[app_commands.Transformer]:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
try:
return func(value)
except CommandError:
raise
except Exception as exc:
raise BadArgument(f'Converting to "{func.__name__}" failed') from exc
return type('CallableTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Signature) -> List[inspect.Parameter]:
# Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly
params = signature.parameters.copy()
for name, parameter in parameters.items():
if is_converter(parameter.converter) and not hasattr(parameter.converter, '__discord_app_commands_transformer__'):
is_transformer = hasattr(parameter.converter, '__discord_app_commands_transformer__')
if is_converter(parameter.converter) and not is_transformer:
params[name] = params[name].replace(annotation=make_converter_transformer(parameter.converter))
if callable(parameter.converter) and not inspect.isclass(parameter.converter) and not is_transformer:
params[name] = params[name].replace(annotation=make_callable_transformer(parameter.converter))
if callable(parameter.default):
params[name] = params[name].replace(default=_CallableDefault(parameter.default))
if isinstance(params[name].default, Parameter):
# If we're here, then then it hasn't been handled yet so it should be removed completely
params[name] = params[name].replace(default=parameter.empty)
return list(params.values())
@ -146,6 +179,28 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
}
return self._copy_with(parent=self.parent, binding=self.binding, bindings=bindings)
async def _transform_arguments(
self, interaction: discord.Interaction, namespace: app_commands.Namespace
) -> Dict[str, Any]:
values = namespace.__dict__
transformed_values = {}
for param in self._params.values():
try:
value = values[param.display_name]
except KeyError:
if not param.required:
if isinstance(param.default, _CallableDefault):
transformed_values[param.name] = await maybe_coroutine(param.default.func, interaction._baton)
else:
transformed_values[param.name] = param.default
else:
raise app_commands.CommandSignatureMismatch(self) from None
else:
transformed_values[param.name] = await param.transform(interaction, value)
return transformed_values
async def _check_can_run(self, interaction: discord.Interaction) -> bool:
# Hybrid checks must run like so:
# - Bot global check once

Loading…
Cancel
Save