From fafc5b13f6677f57fa3f5b64cb873b3a4cf1fb4c Mon Sep 17 00:00:00 2001 From: Josh Date: Sat, 19 Mar 2022 20:34:19 +1000 Subject: [PATCH] [commands] Rework help command to avoid a deepcopy on invoke --- discord/ext/commands/context.py | 5 +- discord/ext/commands/help.py | 287 +++++++++++--------------------- 2 files changed, 99 insertions(+), 193 deletions(-) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 3297c056f..c9ca232f4 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -354,6 +354,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): """ from .core import Group, Command, wrap_callback from .errors import CommandError + from .help import _context bot = self.bot cmd = bot.help_command @@ -361,8 +362,8 @@ class Context(discord.abc.Messageable, Generic[BotT]): if cmd is None: return None - cmd = cmd.copy() - cmd.context = self # type: ignore + _context.set(self) + if len(args) == 0: await cmd.prepare_help_command(self, None) mapping = cmd.get_bot_mapping() diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index 259cf2f9f..afbdb1d43 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations +from contextvars import ContextVar import itertools -import copy import functools import re @@ -33,12 +33,12 @@ from typing import ( TYPE_CHECKING, Optional, Generator, + Generic, List, TypeVar, Callable, Any, Dict, - Tuple, Iterable, Sequence, Mapping, @@ -50,7 +50,6 @@ from .core import Group, Command, get_signature_parameters from .errors import CommandError if TYPE_CHECKING: - from typing_extensions import Self import inspect import discord.abc @@ -59,13 +58,6 @@ if TYPE_CHECKING: from .context import Context from .cog import Cog - from ._types import ( - Check, - ContextT, - BotT, - _Bot, - ) - __all__ = ( 'Paginator', 'HelpCommand', @@ -73,7 +65,11 @@ __all__ = ( 'MinimalHelpCommand', ) +T = TypeVar('T') + +ContextT = TypeVar('ContextT', bound='Context') FuncT = TypeVar('FuncT', bound=Callable[..., Any]) +HelpCommandCommand = Command[Optional['Cog'], ... if TYPE_CHECKING else Any, Any] MISSING: Any = discord.utils.MISSING @@ -219,92 +215,12 @@ def _not_overridden(f: FuncT) -> FuncT: return f -class _HelpCommandImpl(Command): - def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None: - super().__init__(inject.command_callback, *args, **kwargs) - self._original: HelpCommand = inject - self._injected: HelpCommand = inject - self.params: Dict[str, inspect.Parameter] = get_signature_parameters( - inject.command_callback, globals(), skip_parameters=1 - ) - - async def prepare(self, ctx: Context[Any]) -> None: - self._injected = injected = self._original.copy() - injected.context = ctx - self.callback = injected.command_callback - self.params = get_signature_parameters(injected.command_callback, globals(), skip_parameters=1) - - on_error = injected.on_help_command_error - if not hasattr(on_error, '__help_command_not_overridden__'): - if self.cog is not None: - self.on_error = self._on_error_cog_implementation - else: - self.on_error = on_error - - await super().prepare(ctx) - - async def _parse_arguments(self, ctx: Context[BotT]) -> None: - # Make the parser think we don't have a cog so it doesn't - # inject the parameter into `ctx.args`. - original_cog = self.cog - self.cog = None - try: - await super()._parse_arguments(ctx) - finally: - self.cog = original_cog - - async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None: - await self._injected.on_help_command_error(ctx, error) - - def _inject_into_cog(self, cog: Cog) -> None: - # Warning: hacky - - # Make the cog think that get_commands returns this command - # as well if we inject it without modifying __cog_commands__ - # since that's used for the injection and ejection of cogs. - def wrapped_get_commands( - *, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands - ) -> List[Command[Any, ..., Any]]: - ret = _original() - ret.append(self) - return ret - - # Ditto here - def wrapped_walk_commands( - *, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands - ): - yield from _original() - yield self - - functools.update_wrapper(wrapped_get_commands, cog.get_commands) - functools.update_wrapper(wrapped_walk_commands, cog.walk_commands) - cog.get_commands = wrapped_get_commands - cog.walk_commands = wrapped_walk_commands - self.cog = cog - - def _eject_cog(self) -> None: - if self.cog is None: - return - - # revert back into their original methods - cog = self.cog - cog.get_commands = cog.get_commands.__wrapped__ - cog.walk_commands = cog.walk_commands.__wrapped__ - self.cog = None +_context: ContextVar[Optional[Context]] = ContextVar('context', default=None) -class HelpCommand: +class HelpCommand(HelpCommandCommand, Generic[ContextT]): r"""The base implementation for help command formatting. - .. note:: - - Internally instances of this class are deep copied every time - the command itself is invoked to prevent a race condition - mentioned in :issue:`2123`. - - This means that relying on the state of this class to be - the same between command invocations would not work as expected. - Attributes ------------ context: Optional[:class:`Context`] @@ -336,88 +252,53 @@ class HelpCommand: MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) - if TYPE_CHECKING: - __original_kwargs__: Dict[str, Any] - __original_args__: Tuple[Any, ...] - - def __new__(cls, *args: Any, **kwargs: Any) -> Self: - # To prevent race conditions of a single instance while also allowing - # for settings to be passed the original arguments passed must be assigned - # to allow for easier copies (which will be made when the help command is actually called) - # see issue 2123 - self = super().__new__(cls) - - # Shallow copies cannot be used in this case since it is not unusual to pass - # instances that need state, e.g. Paginator or what have you into the function - # The keys can be safely copied as-is since they're 99.99% certain of being - # string keys - deepcopy = copy.deepcopy - self.__original_kwargs__ = {k: deepcopy(v) for k, v in kwargs.items()} - self.__original_args__ = deepcopy(args) - return self - - def __init__(self, **options: Any) -> None: - self.show_hidden: bool = options.pop('show_hidden', False) - self.verify_checks: bool = options.pop('verify_checks', True) - self.command_attrs: Dict[str, Any] - self.command_attrs = attrs = options.pop('command_attrs', {}) + def __init__( + self, + *, + show_hidden: bool = False, + verify_checks: bool = True, + command_attrs: Dict[str, Any] = MISSING, + ) -> None: + self.show_hidden: bool = show_hidden + self.verify_checks: bool = verify_checks + self.command_attrs = attrs = command_attrs if command_attrs is not MISSING else {} attrs.setdefault('name', 'help') attrs.setdefault('help', 'Shows this message') - self.context: Context[_Bot] = MISSING - self._command_impl = _HelpCommandImpl(self, **self.command_attrs) - - def copy(self) -> Self: - obj = self.__class__(*self.__original_args__, **self.__original_kwargs__) - obj._command_impl = self._command_impl - return obj - - def _add_to_bot(self, bot: BotBase) -> None: - command = _HelpCommandImpl(self, **self.command_attrs) - bot.add_command(command) - self._command_impl = command - - def _remove_from_bot(self, bot: BotBase) -> None: - bot.remove_command(self._command_impl.name) - self._command_impl._eject_cog() - - def add_check(self, func: Check[ContextT], /) -> None: - """ - Adds a check to the help command. - - .. versionadded:: 1.4 - - .. versionchanged:: 2.0 - - ``func`` parameter is now positional-only. - - Parameters - ---------- - func - The function that will be used as a check. - """ - - self._command_impl.add_check(func) - - def remove_check(self, func: Check[ContextT], /) -> None: - """ - Removes a check from the help command. - - This function is idempotent and will not raise an exception if - the function is not in the command's checks. + self._cog: Optional[Cog] = None + super().__init__(self._set_context, **attrs) + self.params: Dict[str, inspect.Parameter] = get_signature_parameters( + self.command_callback, globals(), skip_parameters=1 + ) - .. versionadded:: 1.4 + async def __call__(self, context: ContextT, *args: Any, **kwargs: Any) -> Any: + return await self.command_callback(context, *args, **kwargs) - .. versionchanged:: 2.0 + async def _set_context(self, context: ContextT, *args: Any, **kwargs: Any) -> Any: + _context.set(context) + return await self.command_callback(context, *args, **kwargs) - ``func`` parameter is now positional-only. + @property + def context(self) -> ContextT: + ctx = _context.get() + if ctx is None: + raise AttributeError('context attribute cannot be accessed in non command-invocation contexts.') + return ctx # type: ignore - Parameters - ---------- - func - The function to remove from the checks. - """ + def _add_to_bot(self, bot: BotBase) -> None: + bot.add_command(self) # type: ignore - self._command_impl.remove_check(func) + def _remove_from_bot(self, bot: BotBase) -> None: + bot.remove_command(self) # type: ignore + self._eject_cog() + + async def invoke(self, ctx: ContextT) -> None: + # we need to temporarily set the cog to None to prevent the cog + # from being passed into the command callback. + cog = self._cog + self._cog = None + await self.prepare(ctx) + self._cog = cog + await self.callback(*ctx.args, **ctx.kwargs) def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]: """Retrieves the bot mapping passed to :meth:`send_bot_help`.""" @@ -441,7 +322,7 @@ class HelpCommand: Optional[:class:`str`] The command name that triggered this invocation. """ - command_name = self._command_impl.name + command_name = self.name ctx = self.context if ctx is MISSING or ctx.command is None or ctx.command.qualified_name != command_name: return command_name @@ -498,31 +379,54 @@ class HelpCommand: return self.MENTION_PATTERN.sub(replace, string) - @property - def cog(self) -> Optional[Cog]: - """A property for retrieving or setting the cog for the help command. + def _inject_into_cog(self, cog: Cog) -> None: + # Warning: hacky + + # Make the cog think that get_commands returns this command + # as well if we inject it without modifying __cog_commands__ + # since that's used for the injection and ejection of cogs. + def wrapped_get_commands( + *, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands + ) -> List[Command[Any, ..., Any]]: + ret = _original() + ret.append(self) + return ret + + # Ditto here + def wrapped_walk_commands( + *, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands + ): + yield from _original() + yield self + + functools.update_wrapper(wrapped_get_commands, cog.get_commands) + functools.update_wrapper(wrapped_walk_commands, cog.walk_commands) + cog.get_commands = wrapped_get_commands + cog.walk_commands = wrapped_walk_commands + self._cog = cog - When a cog is set for the help command, it is as-if the help command - belongs to that cog. All cog special methods will apply to the help - command and it will be automatically unset on unload. + def _eject_cog(self) -> None: + if self._cog is None: + return - To unbind the cog from the help command, you can set it to ``None``. + # revert back into their original methods + cog = self._cog + cog.get_commands = cog.get_commands.__wrapped__ + cog.walk_commands = cog.walk_commands.__wrapped__ + self._cog = None - Returns - -------- - Optional[:class:`Cog`] - The cog that is currently set for the help command. - """ - return self._command_impl.cog + @property + def cog(self) -> Optional[Cog]: + return self._cog @cog.setter def cog(self, cog: Optional[Cog]) -> None: # Remove whatever cog is currently valid, if any - self._command_impl._eject_cog() + self._eject_cog() # If a new cog is set then inject it. if cog is not None: - self._command_impl._inject_into_cog(cog) + self._inject_into_cog(cog) def command_not_found(self, string: str) -> str: """|maybecoro| @@ -693,7 +597,7 @@ class HelpCommand: await destination.send(error) @_not_overridden - async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None: + async def on_help_command_error(self, ctx: ContextT, error: CommandError) -> None: """|coro| The help command's error handler, as specified by :ref:`ext_commands_error_handler`. @@ -836,7 +740,7 @@ class HelpCommand: """ return None - async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None: + async def prepare_help_command(self, ctx: ContextT, command: Optional[str] = None) -> None: """|coro| A low level method that can be used to prepare the help command @@ -860,7 +764,7 @@ class HelpCommand: """ pass - async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None: + async def command_callback(self, ctx: ContextT, *, command: Optional[str] = None) -> Any: """|coro| The actual implementation of the help command. @@ -880,6 +784,7 @@ class HelpCommand: - :meth:`prepare_help_command` """ await self.prepare_help_command(ctx, command) + bot = ctx.bot if command is None: @@ -905,7 +810,7 @@ class HelpCommand: for key in keys[1:]: try: - found = cmd.all_commands.get(key) # type: ignore + found = cmd.all_commands.get(key) except AttributeError: string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) @@ -921,7 +826,7 @@ class HelpCommand: return await self.send_command_help(cmd) -class DefaultHelpCommand(HelpCommand): +class DefaultHelpCommand(HelpCommand[ContextT]): """The implementation of the default help command. This inherits from :class:`HelpCommand`. @@ -1062,7 +967,7 @@ class DefaultHelpCommand(HelpCommand): else: return ctx.channel - async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: + async def prepare_help_command(self, ctx: ContextT, command: str) -> None: self.paginator.clear() await super().prepare_help_command(ctx, command) @@ -1130,7 +1035,7 @@ class DefaultHelpCommand(HelpCommand): await self.send_pages() -class MinimalHelpCommand(HelpCommand): +class MinimalHelpCommand(HelpCommand[ContextT]): """An implementation of a help command with minimal output. This inherits from :class:`HelpCommand`. @@ -1306,7 +1211,7 @@ class MinimalHelpCommand(HelpCommand): else: return ctx.channel - async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: + async def prepare_help_command(self, ctx: ContextT, command: str) -> None: self.paginator.clear() await super().prepare_help_command(ctx, command)