From adbf2c720f192d20d6bd71bac55d1c5057a8baa1 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 19 Jun 2016 22:06:09 -0400 Subject: [PATCH] [commands] Add the concept of global checks. Global checks are checks that are executed before regular per-command checks except done to every command that the bot has registered. This allows you to have checks that apply to every command without having to override `on_message` or appending the check to every single command. --- discord/ext/commands/bot.py | 87 +++++++++++++++++++++++++++++++++++- discord/ext/commands/core.py | 5 ++- 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 4cbc82471..44a9f3521 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -208,6 +208,7 @@ class Bot(GroupMixin, discord.Client): self.extra_events = {} self.cogs = {} self.extensions = {} + self._checks = [] self.description = inspect.cleandoc(description) if description else '' self.pm_help = pm_help self.command_not_found = options.pop('command_not_found', 'No command called "{}" found.') @@ -443,6 +444,70 @@ class Bot(GroupMixin, discord.Client): destination = _get_variable('_internal_channel') return self.send_typing(destination) + # global check registration + + def check(self): + """A decorator that adds a global check to the bot. + + A global check is similar to a :func:`check` that is applied + on a per command basis except it is run before any command checks + have been verified and applies to every command the bot has. + + .. warning:: + + This function must be a *regular* function and not a coroutine. + + Similar to a command :func:`check`\, this takes a single parameter + of type :class:`Context` and can only raise exceptions derived from + :exc:`CommandError`. + + Example + --------- + + .. code-block:: python + + @bot.check + def whitelist(ctx): + return ctx.message.author.id in my_whitelist + + """ + def decorator(func): + self.add_check(func) + return func + return decorator + + def add_check(self, func): + """Adds a global check to the bot. + + This is the non-decorator interface to :meth:`check`. + + Parameters + ----------- + func + The function that was used as a global check. + """ + self._checks.append(func) + + def remove_check(self, func): + """Removes a global check from the bot. + + This function is idempotent and will not raise an exception + if the function is not in the global checks. + + Parameters + ----------- + func + The function to remove from the global checks. + """ + + try: + self._checks.remove(func) + except ValueError: + pass + + def can_run(self, ctx): + return all(f(ctx) for f in self._checks) + # listener registration def add_listener(self, func, name=None): @@ -543,6 +608,9 @@ class Bot(GroupMixin, discord.Client): They are meant as a way to organize multiple relevant commands into a singular class that shares some state or no state at all. + The cog can also have a ``__check`` member function that allows + you to define a global check. See :meth:`check` for more info. + More information will be documented soon. Parameters @@ -552,6 +620,14 @@ class Bot(GroupMixin, discord.Client): """ self.cogs[type(cog).__name__] = cog + + try: + check = getattr(cog, '_{.__class__.__name__}__check'.format(cog)) + except AttributeError: + pass + else: + self.add_check(check) + members = inspect.getmembers(cog) for name, member in members: # register commands the cog has @@ -613,11 +689,20 @@ class Bot(GroupMixin, discord.Client): if name.startswith('on_'): self.remove_listener(member) + try: + check = getattr(cog, '_{0.__class__.__name__}__check'.format(cog)) + except AttributeError: + pass + else: + self.remove_check(check) + unloader_name = '_{0.__class__.__name__}__unload'.format(cog) try: - getattr(cog, unloader_name)() + unloader = getattr(cog, unloader_name) except AttributeError: pass + else: + unloader() del cog diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index ac91933db..330f26061 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -395,8 +395,11 @@ class Command: if self.no_pm and ctx.message.channel.is_private: raise NoPrivateMessage('This command cannot be used in private messages.') + if not ctx.bot.can_run(ctx): + raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self)) + if not self.can_run(ctx): - raise CheckFailure('The check functions for command {0.name} failed.'.format(self)) + raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self)) @asyncio.coroutine def invoke(self, ctx):