From b81fbb5a7f6df4125cddc57feefceef75a1974c1 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Fri, 19 May 2017 21:33:39 -0400 Subject: [PATCH] [commands] Add Context.reinvoke and Command.root_parent Context.reinvoke would be the new way to bypass checks and cooldowns. However, with its addition comes a change in the invocation order of checks, callbacks, and cooldowns. While previously cooldowns would trigger after command argument parsing, the new behaviour parses cooldowns before command argument parsing. The implication of this change is that Context.args and Context.kwargs will no longer be filled properly. --- discord/ext/commands/context.py | 53 +++++++++++++++++++++++ discord/ext/commands/core.py | 77 ++++++++++++++++++++++++++++++++- 2 files changed, 128 insertions(+), 2 deletions(-) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index ef7d9ca68..53aecdb1b 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -121,6 +121,59 @@ class Context(discord.abc.Messageable): ret = yield from command.callback(*arguments, **kwargs) return ret + @asyncio.coroutine + def reinvoke(self, *, call_hooks=False, restart=True): + """|coro| + + Calls the command again. + + This is similar to :meth:`~.Context.invoke` except that it bypasses + checks, cooldowns, and error handlers. + + .. note:: + + If you want to bypass :exc:`.UserInputError` derived exceptions, + it is recommended to use the regular :meth:`~.Context.invoke` + as it will work more naturally. After all, this will end up + using the old arguments the user has used and will thus just + fail again. + + Parameters + ------------ + call_hooks: bool + Whether to call the before and after invoke hooks. + restart: bool + Whether to start the call chain from the very beginning + or where we left off (i.e. the command that caused the error). + """ + cmd = self.command + view = self.view + if cmd is None: + raise ValueError('This context is not valid.') + + # some state to revert to when we're done + index, previous = view.index, view.previous + invoked_with = self.invoked_with + invoked_subcommand = self.invoked_subcommand + subcommand_passed = self.subcommand_passed + + if restart: + to_call = cmd.root_parent or cmd + view.index = len(self.prefix) + 1 + view.previous = 0 + else: + to_call = cmd + + try: + yield from to_call.reinvoke(self, call_hooks=call_hooks) + finally: + self.command = cmd + view.index = index + view.previous = previous + self.invoked_with = invoked_with + self.invoked_subcommand = invoked_subcommand + self.subcommand_passed = subcommand_passed + @property def valid(self): """Checks if the invocation context is valid to be invoked with.""" diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 29a0433c1..2662a31be 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -282,6 +282,24 @@ class Command: return ' '.join(reversed(entries)) + @property + def root_parent(self): + """Retrieves the root parent of this command. + + If the command has no parents then it returns ``None``. + + For example in commands ``?a b c test``, the root parent is + ``a``. + """ + entries = [] + command = self + while command.parent is not None: + command = command.parent + entries.append(command) + entries.append(None) + entries.reverse() + return entries[-1] + @property def qualified_name(self): """Retrieves the fully qualified command name. @@ -350,7 +368,6 @@ class Command: if not view.eof: raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) - @asyncio.coroutine def _verify_checks(self, ctx): if not self.enabled: @@ -407,7 +424,6 @@ class Command: def prepare(self, ctx): ctx.command = self yield from self._verify_checks(ctx) - yield from self._parse_arguments(ctx) if self._buckets.valid: bucket = self._buckets.get_bucket(ctx) @@ -415,6 +431,7 @@ class Command: if retry_after: raise CommandOnCooldown(bucket, retry_after) + yield from self._parse_arguments(ctx) yield from self.call_before_hooks(ctx) def reset_cooldown(self, ctx): @@ -440,6 +457,24 @@ class Command: injected = hooked_wrapped_callback(self, ctx, self.callback) yield from injected(*ctx.args, **ctx.kwargs) + @asyncio.coroutine + def reinvoke(self, ctx, *, call_hooks=False): + ctx.command = self + yield from self._parse_arguments(ctx) + + if call_hooks: + yield from self.call_before_hooks(ctx) + + ctx.invoked_subcommand = None + try: + yield from self.callback(*ctx.args, **ctx.kwargs) + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + yield from self.call_after_hooks(ctx) + def error(self, coro): """A decorator that registers a coroutine as a local error handler. @@ -821,6 +856,44 @@ class Group(GroupMixin, Command): view.previous = previous yield from super().invoke(ctx) + @asyncio.coroutine + def reinvoke(self, ctx, *, call_hooks=False): + early_invoke = not self.invoke_without_command + if early_invoke: + ctx.command = self + yield from self._parse_arguments(ctx) + + if call_hooks: + yield from self.call_before_hooks(ctx) + + view = ctx.view + previous = view.index + view.skip_ws() + trigger = view.get_word() + + if trigger: + ctx.subcommand_passed = trigger + ctx.invoked_subcommand = self.all_commands.get(trigger, None) + + if early_invoke: + try: + yield from self.callback(*ctx.args, **ctx.kwargs) + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + yield from self.call_after_hooks(ctx) + + if trigger and ctx.invoked_subcommand: + ctx.invoked_with = trigger + yield from ctx.invoked_subcommand.reinvoke(ctx, call_hooks=call_hooks) + elif not early_invoke: + # undo the trigger parsing + view.index = previous + view.previous = previous + yield from super().reinvoke(ctx, call_hooks=call_hooks) + # Decorators def command(name=None, cls=None, **attrs):