diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 1913a3198..d080d64a3 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -136,6 +136,8 @@ class BotBase(GroupMixin): self.cogs = {} self.extensions = {} self._checks = [] + self._before_invoke = None + self._after_invoke = None self.description = inspect.cleandoc(description) if description else '' self.pm_help = pm_help self.command_not_found = options.pop('command_not_found', 'No command called "{}" found.') @@ -269,6 +271,71 @@ class BotBase(GroupMixin): def can_run(self, ctx): return all(f(ctx) for f in self._checks) + def before_invoke(self, coro): + """A decorator that registers a coroutine as a pre-invoke hook. + + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`Context`. + + .. note:: + + The :meth:`before_invoke` and :meth:`after_invoke` hooks are + only called if all checks and argument parsing procedures pass + without error. If any check or argument parsing procedures fail + then the hooks are not called. + + Parameters + ----------- + coro + The coroutine to register as the pre-invoke hook. + + Raises + ------- + discord.ClientException + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException('The error handler must be a coroutine.') + + self._before_invoke = coro + return coro + + def after_invoke(self, coro): + """A decorator that registers a coroutine as a post-invoke hook. + + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`Context`. + + .. note:: + + Similar to :meth:`before_invoke`\, this is not called unless + checks and argument parsing procedures succeed. This hook is, + however, **always** called regardless of the internal command + callback raising an error (i.e. :exc:`CommandInvokeError`\). + This makes it ideal for clean-up scenarios. + + Parameters + ----------- + coro + The coroutine to register as the post-invoke hook. + + Raises + ------- + discord.ClientException + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException('The error handler must be a coroutine.') + + self._after_invoke = coro + return coro + # listener registration def add_listener(self, func, name=None): diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 51a13f94a..d1a91133b 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -52,6 +52,21 @@ def wrap_callback(coro): return ret return wrapped +def hooked_wrapped_callback(command, ctx, coro): + @functools.wraps(coro) + @asyncio.coroutine + def wrapped(*args, **kwargs): + try: + ret = yield from coro(*args, **kwargs) + except CommandError: + raise + except Exception as e: + raise CommandInvokeError(e) from e + finally: + yield from command.call_after_hooks(ctx) + return ret + return wrapped + def _convert_to_bool(argument): lowered = argument.lower() if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): @@ -144,6 +159,8 @@ class Command: self.instance = None self.parent = None self._buckets = CooldownMapping(kwargs.get('cooldown')) + self._before_invoke = None + self._after_invoke = None @asyncio.coroutine def dispatch_error(self, error, ctx): @@ -335,6 +352,50 @@ class Command: if not self.can_run(ctx): raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self)) + @asyncio.coroutine + def call_before_hooks(self, ctx): + # now that we're done preparing we can call the pre-command hooks + # first, call the command local hook: + cog = self.instance + if self._before_invoke is not None: + if cog is None: + yield from self._before_invoke(ctx) + else: + yield from self._before_invoke(cog, ctx) + + # call the cog local hook if applicable: + try: + hook = getattr(cog, '_{0.__class__.__name__}__before_invoke'.format(cog)) + except AttributeError: + pass + else: + yield from hook(ctx) + + # call the bot global hook if necessary + hook = ctx.bot._before_invoke + if hook is not None: + yield from hook(ctx) + + @asyncio.coroutine + def call_after_hooks(self, ctx): + cog = self.instance + if self._after_invoke is not None: + if cog is None: + yield from self._after_invoke(ctx) + else: + yield from self._after_invoke(cog, ctx) + + try: + hook = getattr(cog, '_{0.__class__.__name__}__after_invoke'.format(cog)) + except AttributeError: + pass + else: + yield from hook(ctx) + + hook = ctx.bot._after_invoke + if hook is not None: + yield from hook(ctx) + @asyncio.coroutine def prepare(self, ctx): ctx.command = self @@ -347,6 +408,8 @@ class Command: if retry_after: raise CommandOnCooldown(bucket, retry_after) + yield from self.call_before_hooks(ctx) + def reset_cooldown(self, ctx): """Resets the cooldown on this command. @@ -367,7 +430,7 @@ class Command: # since we're in a regular command (and not a group) then # the invoked subcommand is None. ctx.invoked_subcommand = None - injected = wrap_callback(self.callback) + injected = hooked_wrapped_callback(self, ctx, self.callback) yield from injected(*ctx.args, **ctx.kwargs) def error(self, coro): @@ -394,6 +457,60 @@ class Command: self.on_error = coro return coro + def before_invoke(self, coro): + """A decorator that registers a coroutine as a pre-invoke hook. + + A pre-invoke hook is called directly before :meth:`invoke` is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`Context`. + + See :meth:`Bot.before_invoke` for more info. + + Parameters + ----------- + coro + The coroutine to register as the pre-invoke hook. + + Raises + ------- + discord.ClientException + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException('The error handler must be a coroutine.') + + self._before_invoke = coro + return coro + + def after_invoke(self, coro): + """A decorator that registers a coroutine as a post-invoke hook. + + A post-invoke hook is called directly after :meth:`invoke` is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`Context`. + + See :meth:`Bot.after_invoke` for more info. + + Parameters + ----------- + coro + The coroutine to register as the post-invoke hook. + + Raises + ------- + discord.ClientException + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException('The error handler must be a coroutine.') + + self._after_invoke = coro + return coro + @property def cog_name(self): """The name of the cog this command belongs to. None otherwise.""" @@ -610,7 +727,7 @@ class Group(GroupMixin, Command): ctx.invoked_subcommand = self.commands.get(trigger, None) if early_invoke: - injected = wrap_callback(self.callback) + injected = hooked_wrapped_callback(self, ctx, self.callback) yield from injected(*ctx.args, **ctx.kwargs) if trigger and ctx.invoked_subcommand: