From ae3dac0d5971e156f04fecafe5cc1773b019f93f Mon Sep 17 00:00:00 2001 From: Rapptz Date: Mon, 6 Jan 2020 22:03:27 -0500 Subject: [PATCH] [commands] Add check_any check to OR together various checks --- discord/ext/commands/core.py | 70 ++++++++++++++++++++++++++++++++++ discord/ext/commands/errors.py | 21 ++++++++++ docs/ext/commands/api.rst | 2 + 3 files changed, 93 insertions(+) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index f29ada945..5a7093d83 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -48,6 +48,7 @@ __all__ = ( 'has_permissions', 'has_any_role', 'check', + 'check_any', 'bot_has_role', 'bot_has_permissions', 'bot_has_any_role', @@ -1379,6 +1380,75 @@ def check(predicate): decorator.predicate = predicate return decorator +def check_any(*checks): + """A :func:`check` that is added that checks if any of the checks passed + will pass, i.e. using logical OR. + + If all checks fail then :exc:`.CheckAnyFailure` is raised to signal the failure. + It inherits from :exc:`.CheckFailure`. + + .. note:: + + The ``predicate`` attribute for this function **is** a coroutine. + + .. versionadded:: 1.3.0 + + Parameters + ------------ + \*checks: Callable[[:class:`Context`], :class:`bool`] + An argument list of checks that have been decorated with + the :func:`check` decorator. + + Raises + ------- + TypeError + A check passed has not been decorated with the :func:`check` + decorator. + + Examples + --------- + + Creating a basic check to see if it's the bot owner or + the server owner: + + .. code-block:: python3 + + def is_guild_owner(): + def predicate(ctx): + return ctx.guild is not None and ctx.guild.owner_id == ctx.author.id + return commands.check(predicate) + + @bot.command() + @commands.check_any(commands.is_owner(), is_guild_owner()) + async def only_for_owners(ctx): + await ctx.send('Hello mister owner!') + """ + + unwrapped = [] + for wrapped in checks: + try: + pred = wrapped.predicate + except AttributeError: + raise TypeError('%r must be wrapped by commands.check decorator' % wrapped) from None + else: + unwrapped.append(pred) + + async def predicate(ctx): + errors = [] + maybe = discord.utils.maybe_coroutine + for func in unwrapped: + try: + value = await maybe(func, ctx) + except CheckFailure as e: + errors.append(e) + else: + if value: + return True + # if we're here, all checks failed + raise CheckAnyFailure(unwrapped, errors) + + return check(predicate) + def has_role(item): """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index fa92c0493..71a0098a4 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -34,6 +34,7 @@ __all__ = ( 'PrivateMessageOnly', 'NoPrivateMessage', 'CheckFailure', + 'CheckAnyFailure', 'CommandNotFound', 'DisabledCommand', 'CommandInvokeError', @@ -153,6 +154,26 @@ class CheckFailure(CommandError): """ pass +class CheckAnyFailure(CheckFailure): + """Exception raised when all predicates in :func:`check_any` fail. + + This inherits from :exc:`CheckFailure`. + + .. versionadded:: 1.3 + + Attributes + ------------ + errors: List[:class:`CheckFailure`] + A list of errors that were caught during execution. + checks: List[Callable[[:class:`Context`], :class:`bool`]] + A list of check predicates that failed. + """ + + def __init__(self, checks, errors): + self.checks = checks + self.errors = errors + super().__init__('You do not have permission to run this command.') + class PrivateMessageOnly(CheckFailure): """Exception raised when an operation does not work outside of private message contexts. diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index 71efb2b52..3672823a3 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -118,6 +118,8 @@ Checks .. autofunction:: discord.ext.commands.check +.. autofunction:: discord.ext.commands.check_any + .. autofunction:: discord.ext.commands.has_role .. autofunction:: discord.ext.commands.has_permissions