diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index ef7d9ca68..53aecdb1b 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -121,6 +121,59 @@ class Context(discord.abc.Messageable): ret = yield from command.callback(*arguments, **kwargs) return ret + @asyncio.coroutine + def reinvoke(self, *, call_hooks=False, restart=True): + """|coro| + + Calls the command again. + + This is similar to :meth:`~.Context.invoke` except that it bypasses + checks, cooldowns, and error handlers. + + .. note:: + + If you want to bypass :exc:`.UserInputError` derived exceptions, + it is recommended to use the regular :meth:`~.Context.invoke` + as it will work more naturally. After all, this will end up + using the old arguments the user has used and will thus just + fail again. + + Parameters + ------------ + call_hooks: bool + Whether to call the before and after invoke hooks. + restart: bool + Whether to start the call chain from the very beginning + or where we left off (i.e. the command that caused the error). + """ + cmd = self.command + view = self.view + if cmd is None: + raise ValueError('This context is not valid.') + + # some state to revert to when we're done + index, previous = view.index, view.previous + invoked_with = self.invoked_with + invoked_subcommand = self.invoked_subcommand + subcommand_passed = self.subcommand_passed + + if restart: + to_call = cmd.root_parent or cmd + view.index = len(self.prefix) + 1 + view.previous = 0 + else: + to_call = cmd + + try: + yield from to_call.reinvoke(self, call_hooks=call_hooks) + finally: + self.command = cmd + view.index = index + view.previous = previous + self.invoked_with = invoked_with + self.invoked_subcommand = invoked_subcommand + self.subcommand_passed = subcommand_passed + @property def valid(self): """Checks if the invocation context is valid to be invoked with.""" diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 29a0433c1..2662a31be 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -282,6 +282,24 @@ class Command: return ' '.join(reversed(entries)) + @property + def root_parent(self): + """Retrieves the root parent of this command. + + If the command has no parents then it returns ``None``. + + For example in commands ``?a b c test``, the root parent is + ``a``. + """ + entries = [] + command = self + while command.parent is not None: + command = command.parent + entries.append(command) + entries.append(None) + entries.reverse() + return entries[-1] + @property def qualified_name(self): """Retrieves the fully qualified command name. @@ -350,7 +368,6 @@ class Command: if not view.eof: raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) - @asyncio.coroutine def _verify_checks(self, ctx): if not self.enabled: @@ -407,7 +424,6 @@ class Command: def prepare(self, ctx): ctx.command = self yield from self._verify_checks(ctx) - yield from self._parse_arguments(ctx) if self._buckets.valid: bucket = self._buckets.get_bucket(ctx) @@ -415,6 +431,7 @@ class Command: if retry_after: raise CommandOnCooldown(bucket, retry_after) + yield from self._parse_arguments(ctx) yield from self.call_before_hooks(ctx) def reset_cooldown(self, ctx): @@ -440,6 +457,24 @@ class Command: injected = hooked_wrapped_callback(self, ctx, self.callback) yield from injected(*ctx.args, **ctx.kwargs) + @asyncio.coroutine + def reinvoke(self, ctx, *, call_hooks=False): + ctx.command = self + yield from self._parse_arguments(ctx) + + if call_hooks: + yield from self.call_before_hooks(ctx) + + ctx.invoked_subcommand = None + try: + yield from self.callback(*ctx.args, **ctx.kwargs) + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + yield from self.call_after_hooks(ctx) + def error(self, coro): """A decorator that registers a coroutine as a local error handler. @@ -821,6 +856,44 @@ class Group(GroupMixin, Command): view.previous = previous yield from super().invoke(ctx) + @asyncio.coroutine + def reinvoke(self, ctx, *, call_hooks=False): + early_invoke = not self.invoke_without_command + if early_invoke: + ctx.command = self + yield from self._parse_arguments(ctx) + + if call_hooks: + yield from self.call_before_hooks(ctx) + + view = ctx.view + previous = view.index + view.skip_ws() + trigger = view.get_word() + + if trigger: + ctx.subcommand_passed = trigger + ctx.invoked_subcommand = self.all_commands.get(trigger, None) + + if early_invoke: + try: + yield from self.callback(*ctx.args, **ctx.kwargs) + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + yield from self.call_after_hooks(ctx) + + if trigger and ctx.invoked_subcommand: + ctx.invoked_with = trigger + yield from ctx.invoked_subcommand.reinvoke(ctx, call_hooks=call_hooks) + elif not early_invoke: + # undo the trigger parsing + view.index = previous + view.previous = previous + yield from super().reinvoke(ctx, call_hooks=call_hooks) + # Decorators def command(name=None, cls=None, **attrs):