Browse Source

[commands] Rework help command to avoid a deepcopy on invoke

pull/7723/head
Josh 3 years ago
committed by GitHub
parent
commit
fafc5b13f6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      discord/ext/commands/context.py
  2. 287
      discord/ext/commands/help.py

5
discord/ext/commands/context.py

@ -354,6 +354,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
""" """
from .core import Group, Command, wrap_callback from .core import Group, Command, wrap_callback
from .errors import CommandError from .errors import CommandError
from .help import _context
bot = self.bot bot = self.bot
cmd = bot.help_command cmd = bot.help_command
@ -361,8 +362,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if cmd is None: if cmd is None:
return None return None
cmd = cmd.copy() _context.set(self)
cmd.context = self # type: ignore
if len(args) == 0: if len(args) == 0:
await cmd.prepare_help_command(self, None) await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping() mapping = cmd.get_bot_mapping()

287
discord/ext/commands/help.py

@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar
import itertools import itertools
import copy
import functools import functools
import re import re
@ -33,12 +33,12 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Optional, Optional,
Generator, Generator,
Generic,
List, List,
TypeVar, TypeVar,
Callable, Callable,
Any, Any,
Dict, Dict,
Tuple,
Iterable, Iterable,
Sequence, Sequence,
Mapping, Mapping,
@ -50,7 +50,6 @@ from .core import Group, Command, get_signature_parameters
from .errors import CommandError from .errors import CommandError
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
import inspect import inspect
import discord.abc import discord.abc
@ -59,13 +58,6 @@ if TYPE_CHECKING:
from .context import Context from .context import Context
from .cog import Cog from .cog import Cog
from ._types import (
Check,
ContextT,
BotT,
_Bot,
)
__all__ = ( __all__ = (
'Paginator', 'Paginator',
'HelpCommand', 'HelpCommand',
@ -73,7 +65,11 @@ __all__ = (
'MinimalHelpCommand', 'MinimalHelpCommand',
) )
T = TypeVar('T')
ContextT = TypeVar('ContextT', bound='Context')
FuncT = TypeVar('FuncT', bound=Callable[..., Any]) FuncT = TypeVar('FuncT', bound=Callable[..., Any])
HelpCommandCommand = Command[Optional['Cog'], ... if TYPE_CHECKING else Any, Any]
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
@ -219,92 +215,12 @@ def _not_overridden(f: FuncT) -> FuncT:
return f return f
class _HelpCommandImpl(Command): _context: ContextVar[Optional[Context]] = ContextVar('context', default=None)
def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None:
super().__init__(inject.command_callback, *args, **kwargs)
self._original: HelpCommand = inject
self._injected: HelpCommand = inject
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
inject.command_callback, globals(), skip_parameters=1
)
async def prepare(self, ctx: Context[Any]) -> None:
self._injected = injected = self._original.copy()
injected.context = ctx
self.callback = injected.command_callback
self.params = get_signature_parameters(injected.command_callback, globals(), skip_parameters=1)
on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overridden__'):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
self.on_error = on_error
await super().prepare(ctx)
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
# Make the parser think we don't have a cog so it doesn't
# inject the parameter into `ctx.args`.
original_cog = self.cog
self.cog = None
try:
await super()._parse_arguments(ctx)
finally:
self.cog = original_cog
async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None:
await self._injected.on_help_command_error(ctx, error)
def _inject_into_cog(self, cog: Cog) -> None:
# Warning: hacky
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
functools.update_wrapper(wrapped_get_commands, cog.get_commands)
functools.update_wrapper(wrapped_walk_commands, cog.walk_commands)
cog.get_commands = wrapped_get_commands
cog.walk_commands = wrapped_walk_commands
self.cog = cog
def _eject_cog(self) -> None:
if self.cog is None:
return
# revert back into their original methods
cog = self.cog
cog.get_commands = cog.get_commands.__wrapped__
cog.walk_commands = cog.walk_commands.__wrapped__
self.cog = None
class HelpCommand: class HelpCommand(HelpCommandCommand, Generic[ContextT]):
r"""The base implementation for help command formatting. r"""The base implementation for help command formatting.
.. note::
Internally instances of this class are deep copied every time
the command itself is invoked to prevent a race condition
mentioned in :issue:`2123`.
This means that relying on the state of this class to be
the same between command invocations would not work as expected.
Attributes Attributes
------------ ------------
context: Optional[:class:`Context`] context: Optional[:class:`Context`]
@ -336,88 +252,53 @@ class HelpCommand:
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
if TYPE_CHECKING: def __init__(
__original_kwargs__: Dict[str, Any] self,
__original_args__: Tuple[Any, ...] *,
show_hidden: bool = False,
def __new__(cls, *args: Any, **kwargs: Any) -> Self: verify_checks: bool = True,
# To prevent race conditions of a single instance while also allowing command_attrs: Dict[str, Any] = MISSING,
# for settings to be passed the original arguments passed must be assigned ) -> None:
# to allow for easier copies (which will be made when the help command is actually called) self.show_hidden: bool = show_hidden
# see issue 2123 self.verify_checks: bool = verify_checks
self = super().__new__(cls) self.command_attrs = attrs = command_attrs if command_attrs is not MISSING else {}
# Shallow copies cannot be used in this case since it is not unusual to pass
# instances that need state, e.g. Paginator or what have you into the function
# The keys can be safely copied as-is since they're 99.99% certain of being
# string keys
deepcopy = copy.deepcopy
self.__original_kwargs__ = {k: deepcopy(v) for k, v in kwargs.items()}
self.__original_args__ = deepcopy(args)
return self
def __init__(self, **options: Any) -> None:
self.show_hidden: bool = options.pop('show_hidden', False)
self.verify_checks: bool = options.pop('verify_checks', True)
self.command_attrs: Dict[str, Any]
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help') attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message') attrs.setdefault('help', 'Shows this message')
self.context: Context[_Bot] = MISSING self._cog: Optional[Cog] = None
self._command_impl = _HelpCommandImpl(self, **self.command_attrs) super().__init__(self._set_context, **attrs)
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(
def copy(self) -> Self: self.command_callback, globals(), skip_parameters=1
obj = self.__class__(*self.__original_args__, **self.__original_kwargs__) )
obj._command_impl = self._command_impl
return obj
def _add_to_bot(self, bot: BotBase) -> None:
command = _HelpCommandImpl(self, **self.command_attrs)
bot.add_command(command)
self._command_impl = command
def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self._command_impl.name)
self._command_impl._eject_cog()
def add_check(self, func: Check[ContextT], /) -> None:
"""
Adds a check to the help command.
.. versionadded:: 1.4
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters
----------
func
The function that will be used as a check.
"""
self._command_impl.add_check(func)
def remove_check(self, func: Check[ContextT], /) -> None:
"""
Removes a check from the help command.
This function is idempotent and will not raise an exception if
the function is not in the command's checks.
.. versionadded:: 1.4 async def __call__(self, context: ContextT, *args: Any, **kwargs: Any) -> Any:
return await self.command_callback(context, *args, **kwargs)
.. versionchanged:: 2.0 async def _set_context(self, context: ContextT, *args: Any, **kwargs: Any) -> Any:
_context.set(context)
return await self.command_callback(context, *args, **kwargs)
``func`` parameter is now positional-only. @property
def context(self) -> ContextT:
ctx = _context.get()
if ctx is None:
raise AttributeError('context attribute cannot be accessed in non command-invocation contexts.')
return ctx # type: ignore
Parameters def _add_to_bot(self, bot: BotBase) -> None:
---------- bot.add_command(self) # type: ignore
func
The function to remove from the checks.
"""
self._command_impl.remove_check(func) def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self) # type: ignore
self._eject_cog()
async def invoke(self, ctx: ContextT) -> None:
# we need to temporarily set the cog to None to prevent the cog
# from being passed into the command callback.
cog = self._cog
self._cog = None
await self.prepare(ctx)
self._cog = cog
await self.callback(*ctx.args, **ctx.kwargs)
def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]: def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]:
"""Retrieves the bot mapping passed to :meth:`send_bot_help`.""" """Retrieves the bot mapping passed to :meth:`send_bot_help`."""
@ -441,7 +322,7 @@ class HelpCommand:
Optional[:class:`str`] Optional[:class:`str`]
The command name that triggered this invocation. The command name that triggered this invocation.
""" """
command_name = self._command_impl.name command_name = self.name
ctx = self.context ctx = self.context
if ctx is MISSING or ctx.command is None or ctx.command.qualified_name != command_name: if ctx is MISSING or ctx.command is None or ctx.command.qualified_name != command_name:
return command_name return command_name
@ -498,31 +379,54 @@ class HelpCommand:
return self.MENTION_PATTERN.sub(replace, string) return self.MENTION_PATTERN.sub(replace, string)
@property def _inject_into_cog(self, cog: Cog) -> None:
def cog(self) -> Optional[Cog]: # Warning: hacky
"""A property for retrieving or setting the cog for the help command.
# Make the cog think that get_commands returns this command
# as well if we inject it without modifying __cog_commands__
# since that's used for the injection and ejection of cogs.
def wrapped_get_commands(
*, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands
) -> List[Command[Any, ..., Any]]:
ret = _original()
ret.append(self)
return ret
# Ditto here
def wrapped_walk_commands(
*, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands
):
yield from _original()
yield self
functools.update_wrapper(wrapped_get_commands, cog.get_commands)
functools.update_wrapper(wrapped_walk_commands, cog.walk_commands)
cog.get_commands = wrapped_get_commands
cog.walk_commands = wrapped_walk_commands
self._cog = cog
When a cog is set for the help command, it is as-if the help command def _eject_cog(self) -> None:
belongs to that cog. All cog special methods will apply to the help if self._cog is None:
command and it will be automatically unset on unload. return
To unbind the cog from the help command, you can set it to ``None``. # revert back into their original methods
cog = self._cog
cog.get_commands = cog.get_commands.__wrapped__
cog.walk_commands = cog.walk_commands.__wrapped__
self._cog = None
Returns @property
-------- def cog(self) -> Optional[Cog]:
Optional[:class:`Cog`] return self._cog
The cog that is currently set for the help command.
"""
return self._command_impl.cog
@cog.setter @cog.setter
def cog(self, cog: Optional[Cog]) -> None: def cog(self, cog: Optional[Cog]) -> None:
# Remove whatever cog is currently valid, if any # Remove whatever cog is currently valid, if any
self._command_impl._eject_cog() self._eject_cog()
# If a new cog is set then inject it. # If a new cog is set then inject it.
if cog is not None: if cog is not None:
self._command_impl._inject_into_cog(cog) self._inject_into_cog(cog)
def command_not_found(self, string: str) -> str: def command_not_found(self, string: str) -> str:
"""|maybecoro| """|maybecoro|
@ -693,7 +597,7 @@ class HelpCommand:
await destination.send(error) await destination.send(error)
@_not_overridden @_not_overridden
async def on_help_command_error(self, ctx: Context[BotT], error: CommandError) -> None: async def on_help_command_error(self, ctx: ContextT, error: CommandError) -> None:
"""|coro| """|coro|
The help command's error handler, as specified by :ref:`ext_commands_error_handler`. The help command's error handler, as specified by :ref:`ext_commands_error_handler`.
@ -836,7 +740,7 @@ class HelpCommand:
""" """
return None return None
async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None) -> None: async def prepare_help_command(self, ctx: ContextT, command: Optional[str] = None) -> None:
"""|coro| """|coro|
A low level method that can be used to prepare the help command A low level method that can be used to prepare the help command
@ -860,7 +764,7 @@ class HelpCommand:
""" """
pass pass
async def command_callback(self, ctx: Context[BotT], *, command: Optional[str] = None) -> None: async def command_callback(self, ctx: ContextT, *, command: Optional[str] = None) -> Any:
"""|coro| """|coro|
The actual implementation of the help command. The actual implementation of the help command.
@ -880,6 +784,7 @@ class HelpCommand:
- :meth:`prepare_help_command` - :meth:`prepare_help_command`
""" """
await self.prepare_help_command(ctx, command) await self.prepare_help_command(ctx, command)
bot = ctx.bot bot = ctx.bot
if command is None: if command is None:
@ -905,7 +810,7 @@ class HelpCommand:
for key in keys[1:]: for key in keys[1:]:
try: try:
found = cmd.all_commands.get(key) # type: ignore found = cmd.all_commands.get(key)
except AttributeError: except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
return await self.send_error_message(string) return await self.send_error_message(string)
@ -921,7 +826,7 @@ class HelpCommand:
return await self.send_command_help(cmd) return await self.send_command_help(cmd)
class DefaultHelpCommand(HelpCommand): class DefaultHelpCommand(HelpCommand[ContextT]):
"""The implementation of the default help command. """The implementation of the default help command.
This inherits from :class:`HelpCommand`. This inherits from :class:`HelpCommand`.
@ -1062,7 +967,7 @@ class DefaultHelpCommand(HelpCommand):
else: else:
return ctx.channel return ctx.channel
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: async def prepare_help_command(self, ctx: ContextT, command: str) -> None:
self.paginator.clear() self.paginator.clear()
await super().prepare_help_command(ctx, command) await super().prepare_help_command(ctx, command)
@ -1130,7 +1035,7 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
class MinimalHelpCommand(HelpCommand): class MinimalHelpCommand(HelpCommand[ContextT]):
"""An implementation of a help command with minimal output. """An implementation of a help command with minimal output.
This inherits from :class:`HelpCommand`. This inherits from :class:`HelpCommand`.
@ -1306,7 +1211,7 @@ class MinimalHelpCommand(HelpCommand):
else: else:
return ctx.channel return ctx.channel
async def prepare_help_command(self, ctx: Context[BotT], command: str) -> None: async def prepare_help_command(self, ctx: ContextT, command: str) -> None:
self.paginator.clear() self.paginator.clear()
await super().prepare_help_command(ctx, command) await super().prepare_help_command(ctx, command)

Loading…
Cancel
Save