diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index d080d64a3..911664f46 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -85,7 +85,7 @@ def _default_help_command(ctx, *commands : str): # help by itself just lists our own commands. if len(commands) == 0: - pages = bot.formatter.format_help_for(ctx, bot) + pages = yield from bot.formatter.format_help_for(ctx, bot) elif len(commands) == 1: # try to see if it is a cog name name = _mention_pattern.sub(repl, commands[0]) @@ -98,7 +98,7 @@ def _default_help_command(ctx, *commands : str): yield from destination.send(bot.command_not_found.format(name)) return - pages = bot.formatter.format_help_for(ctx, command) + pages = yield from bot.formatter.format_help_for(ctx, command) else: name = _mention_pattern.sub(repl, commands[0]) command = bot.commands.get(name) @@ -117,7 +117,7 @@ def _default_help_command(ctx, *commands : str): yield from destination.send(bot.command_has_no_subcommands.format(command, key)) return - pages = bot.formatter.format_help_for(ctx, command) + pages = yield from bot.formatter.format_help_for(ctx, command) if bot.pm_help is None: characters = sum(map(lambda l: len(l), pages)) @@ -218,9 +218,9 @@ class BotBase(GroupMixin): on a per command basis except it is run before any command checks have been verified and applies to every command the bot has. - .. warning:: + .. info:: - This function must be a *regular* function and not a coroutine. + This function can either be a regular function or a coroutine. Similar to a command :func:`check`\, this takes a single parameter of type :class:`Context` and can only raise exceptions derived from @@ -268,8 +268,12 @@ class BotBase(GroupMixin): except ValueError: pass + @asyncio.coroutine def can_run(self, ctx): - return all(f(ctx) for f in self._checks) + if len(self._checks) == 0: + return True + + return (yield from discord.utils.async_all(f(ctx) for f in self._checks)) def before_invoke(self, coro): """A decorator that registers a coroutine as a pre-invoke hook. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 29ebd9d86..38b5ed6ac 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -342,6 +342,7 @@ class Command: raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) + @asyncio.coroutine def _verify_checks(self, ctx): if not self.enabled: raise DisabledCommand('{0.name} command is disabled'.format(self)) @@ -349,10 +350,7 @@ class Command: if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel): raise NoPrivateMessage('This command cannot be used in private messages.') - if not ctx.bot.can_run(ctx): - raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self)) - - if not self.can_run(ctx): + if not (yield from self.can_run(ctx)): raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self)) @asyncio.coroutine @@ -402,7 +400,7 @@ class Command: @asyncio.coroutine def prepare(self, ctx): ctx.command = self - self._verify_checks(ctx) + yield from self._verify_checks(ctx) yield from self._parse_arguments(ctx) if self._buckets.valid: @@ -533,14 +531,17 @@ class Command: return self.help.split('\n', 1)[0] return '' - def can_run(self, context): - """Checks if the command can be executed by checking all the predicates + @asyncio.coroutine + def can_run(self, ctx): + """|coro| + + Checks if the command can be executed by checking all the predicates inside the :attr:`checks` attribute. Parameters ----------- - context : :class:`Context` - The context of the command currently being invoked. + ctx: :class:`Context` + The ctx of the command currently being invoked. Returns -------- @@ -548,6 +549,9 @@ class Command: A boolean indicating if the command can be invoked. """ + if not (yield from ctx.bot.can_run(ctx)): + raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self)) + cog = self.instance if cog is not None: try: @@ -555,14 +559,16 @@ class Command: except AttributeError: pass else: - if not local_check(context): + ret = yield from discord.utils.maybe_coroutine(local_check, ctx) + if not ret: return False predicates = self.checks if not predicates: # since we have no checks, then we just return True. return True - return all(predicate(context) for predicate in predicates) + + return (yield from discord.utils.async_all(predicate(ctx) for predicate in predicates)) class GroupMixin: """A mixin that implements common functionality for classes that behave @@ -855,6 +861,10 @@ def check(predicate): will be propagated while those subclassed will be sent to :func:`on_command_error`. + .. info:: + + These functions can either be regular functions or coroutines. + Parameters ----------- predicate diff --git a/discord/ext/commands/formatter.py b/discord/ext/commands/formatter.py index 5fb51cfe7..070578ec4 100644 --- a/discord/ext/commands/formatter.py +++ b/discord/ext/commands/formatter.py @@ -26,9 +26,11 @@ DEALINGS IN THE SOFTWARE. import itertools import inspect +import asyncio from .core import GroupMixin, Command from .errors import CommandError +# from discord.iterators import _FilteredAsyncIterator # help -> shows info of bot on top/bottom and lists subcommands # help command -> shows detailed info of command @@ -227,6 +229,7 @@ class HelpFormatter: return "Type {0}{1} command for more info on a command.\n" \ "You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name) + @asyncio.coroutine def filter_command_list(self): """Returns a filtered list of commands based on the two attributes provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also @@ -238,8 +241,9 @@ class HelpFormatter: An iterable with the filter being applied. The resulting value is a (key, value) tuple of the command name and the command itself. """ - def predicate(tuple): - cmd = tuple[1] + + def sane_no_suspension_point_predicate(tup): + cmd = tup[1] if self.is_cog(): # filter commands that don't exist to this cog. if cmd.instance is not self.command: @@ -248,18 +252,31 @@ class HelpFormatter: if cmd.hidden and not self.show_hidden: return False - if self.show_check_failure: - # we don't wanna bother doing the checks if the user does not - # care about them, so just return true. - return True + return True + + @asyncio.coroutine + def predicate(tup): + if sane_no_suspension_point_predicate(tup) is False: + return False + cmd = tup[1] try: - return cmd.can_run(self.context) and self.context.bot.can_run(self.context) + return (yield from cmd.can_run(self.context)) except CommandError: return False iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items() - return filter(predicate, iterator) + if not self.show_check_failure: + return filter(sane_no_suspension_point_predicate, iterator) + + # Gotta run every check and verify it + ret = [] + for elem in iterator: + valid = yield from predicate(elem) + if valid: + ret.append(elem) + + return ret def _add_subcommands_to_page(self, max_width, commands): for name, command in commands: @@ -271,6 +288,7 @@ class HelpFormatter: shortened = self.shorten(entry) self._paginator.add_line(shortened) + @asyncio.coroutine def format_help_for(self, context, command_or_bot): """Formats the help page and handles the actual heavy lifting of how the help command looks like. To change the behaviour, override the @@ -290,8 +308,9 @@ class HelpFormatter: """ self.context = context self.command = command_or_bot - return self.format() + return (yield from self.format()) + @asyncio.coroutine def format(self): """Handles the actual behaviour involved with formatting. @@ -334,18 +353,19 @@ class HelpFormatter: # last place sorting position. return cog + ':' if cog is not None else '\u200bNo Category:' + filtered = yield from self.filter_command_list() if self.is_bot(): - data = sorted(self.filter_command_list(), key=category) + data = sorted(filtered, key=category) for category, commands in itertools.groupby(data, key=category): # there simply is no prettier way of doing this. - commands = list(commands) + commands = sorted(commands) if len(commands) > 0: self._paginator.add_line(category) self._add_subcommands_to_page(max_width, commands) else: self._paginator.add_line('Commands:') - self._add_subcommands_to_page(max_width, self.filter_command_list()) + self._add_subcommands_to_page(max_width, sorted(filtered)) # add the ending note self._paginator.add_line() diff --git a/discord/iterators.py b/discord/iterators.py index 31d72569a..97d9c27b5 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -30,18 +30,11 @@ import aiohttp import datetime from .errors import NoMoreItems -from .utils import time_snowflake +from .utils import time_snowflake, maybe_coroutine from .object import Object PY35 = sys.version_info >= (3, 5) -@asyncio.coroutine -def _probably_coroutine(f, e): - if asyncio.iscoroutinefunction(f): - return (yield from f(e)) - else: - return f(e) - class _AsyncIterator: __slots__ = () @@ -67,7 +60,7 @@ class _AsyncIterator: except NoMoreItems: return None - ret = yield from _probably_coroutine(predicate, elem) + ret = yield from maybe_coroutine(predicate, elem) if ret: return elem @@ -114,7 +107,7 @@ class _MappedAsyncIterator(_AsyncIterator): def get(self): # this raises NoMoreItems and will propagate appropriately item = yield from self.iterator.get() - return (yield from _probably_coroutine(self.func, item)) + return (yield from maybe_coroutine(self.func, item)) class _FilteredAsyncIterator(_AsyncIterator): def __init__(self, iterator, predicate): @@ -132,7 +125,7 @@ class _FilteredAsyncIterator(_AsyncIterator): while True: # propagate NoMoreItems similar to _MappedAsyncIterator item = yield from getter() - ret = yield from _probably_coroutine(pred, item) + ret = yield from maybe_coroutine(pred, item) if ret: return item diff --git a/discord/utils.py b/discord/utils.py index 1db8e4e09..fe1129d70 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -260,3 +260,19 @@ def _bytes_to_base64_data(data): def to_json(obj): return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) +@asyncio.coroutine +def maybe_coroutine(f, e): + if asyncio.iscoroutinefunction(f): + return (yield from f(e)) + else: + return f(e) + +@asyncio.coroutine +def async_all(gen): + check = asyncio.iscoroutine + for elem in gen: + if check(elem): + elem = yield from elem + if not elem: + return False + return True