Browse Source

[commands] Properly support commands.param in hybrid commands

pull/7881/head
Rapptz 3 years ago
parent
commit
f072edfdfc
  1. 61
      discord/ext/commands/hybrid.py

61
discord/ext/commands/hybrid.py

@ -42,8 +42,9 @@ import inspect
from discord import app_commands from discord import app_commands
from discord.utils import MISSING, maybe_coroutine, async_all from discord.utils import MISSING, maybe_coroutine, async_all
from .core import Command, Group from .core import Command, Group
from .errors import CommandRegistrationError, CommandError, HybridCommandError, ConversionError from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError
from .converter import Converter from .converter import Converter
from .parameters import Parameter
from .cog import Cog from .cog import Cog
if TYPE_CHECKING: if TYPE_CHECKING:
@ -51,7 +52,6 @@ if TYPE_CHECKING:
from ._types import ContextT, Coro, BotT from ._types import ContextT, Coro, BotT
from .bot import Bot from .bot import Bot
from .context import Context from .context import Context
from .parameters import Parameter
from discord.app_commands.commands import ( from discord.app_commands.commands import (
Check as AppCommandCheck, Check as AppCommandCheck,
AutocompleteCallback, AutocompleteCallback,
@ -71,6 +71,7 @@ CogT = TypeVar('CogT', bound='Cog')
CommandT = TypeVar('CommandT', bound='Command') CommandT = TypeVar('CommandT', bound='Command')
# CHT = TypeVar('CHT', bound='Check') # CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group') GroupT = TypeVar('GroupT', bound='Group')
_NoneType = type(None)
if TYPE_CHECKING: if TYPE_CHECKING:
P = ParamSpec('P') P = ParamSpec('P')
@ -85,6 +86,17 @@ else:
P2 = TypeVar('P2') P2 = TypeVar('P2')
class _CallableDefault:
__slots__ = ('func',)
def __init__(self, func: Callable[[Context], Any]) -> None:
self.func: Callable[[Context], Any] = func
@property
def __class__(self) -> Any:
return _NoneType
def is_converter(converter: Any) -> bool: def is_converter(converter: Any) -> bool:
return (inspect.isclass(converter) and issubclass(converter, Converter)) or isinstance(converter, Converter) return (inspect.isclass(converter) and issubclass(converter, Converter)) or isinstance(converter, Converter)
@ -107,12 +119,33 @@ def make_converter_transformer(converter: Any) -> Type[app_commands.Transformer]
return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)}) return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def make_callable_transformer(func: Callable[[str], Any]) -> Type[app_commands.Transformer]:
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
try:
return func(value)
except CommandError:
raise
except Exception as exc:
raise BadArgument(f'Converting to "{func.__name__}" failed') from exc
return type('CallableTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Signature) -> List[inspect.Parameter]: def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Signature) -> List[inspect.Parameter]:
# Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly # Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly
params = signature.parameters.copy() params = signature.parameters.copy()
for name, parameter in parameters.items(): for name, parameter in parameters.items():
if is_converter(parameter.converter) and not hasattr(parameter.converter, '__discord_app_commands_transformer__'): is_transformer = hasattr(parameter.converter, '__discord_app_commands_transformer__')
if is_converter(parameter.converter) and not is_transformer:
params[name] = params[name].replace(annotation=make_converter_transformer(parameter.converter)) params[name] = params[name].replace(annotation=make_converter_transformer(parameter.converter))
if callable(parameter.converter) and not inspect.isclass(parameter.converter) and not is_transformer:
params[name] = params[name].replace(annotation=make_callable_transformer(parameter.converter))
if callable(parameter.default):
params[name] = params[name].replace(default=_CallableDefault(parameter.default))
if isinstance(params[name].default, Parameter):
# If we're here, then then it hasn't been handled yet so it should be removed completely
params[name] = params[name].replace(default=parameter.empty)
return list(params.values()) return list(params.values())
@ -146,6 +179,28 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
} }
return self._copy_with(parent=self.parent, binding=self.binding, bindings=bindings) return self._copy_with(parent=self.parent, binding=self.binding, bindings=bindings)
async def _transform_arguments(
self, interaction: discord.Interaction, namespace: app_commands.Namespace
) -> Dict[str, Any]:
values = namespace.__dict__
transformed_values = {}
for param in self._params.values():
try:
value = values[param.display_name]
except KeyError:
if not param.required:
if isinstance(param.default, _CallableDefault):
transformed_values[param.name] = await maybe_coroutine(param.default.func, interaction._baton)
else:
transformed_values[param.name] = param.default
else:
raise app_commands.CommandSignatureMismatch(self) from None
else:
transformed_values[param.name] = await param.transform(interaction, value)
return transformed_values
async def _check_can_run(self, interaction: discord.Interaction) -> bool: async def _check_can_run(self, interaction: discord.Interaction) -> bool:
# Hybrid checks must run like so: # Hybrid checks must run like so:
# - Bot global check once # - Bot global check once

Loading…
Cancel
Save