Browse Source

Implement async checks. Fixes #380.

pull/1278/head
Rapptz 8 years ago
parent
commit
47ef657fbd
  1. 16
      discord/ext/commands/bot.py
  2. 32
      discord/ext/commands/core.py
  3. 44
      discord/ext/commands/formatter.py
  4. 15
      discord/iterators.py
  5. 16
      discord/utils.py

16
discord/ext/commands/bot.py

@ -85,7 +85,7 @@ def _default_help_command(ctx, *commands : str):
# help by itself just lists our own commands.
if len(commands) == 0:
pages = bot.formatter.format_help_for(ctx, bot)
pages = yield from bot.formatter.format_help_for(ctx, bot)
elif len(commands) == 1:
# try to see if it is a cog name
name = _mention_pattern.sub(repl, commands[0])
@ -98,7 +98,7 @@ def _default_help_command(ctx, *commands : str):
yield from destination.send(bot.command_not_found.format(name))
return
pages = bot.formatter.format_help_for(ctx, command)
pages = yield from bot.formatter.format_help_for(ctx, command)
else:
name = _mention_pattern.sub(repl, commands[0])
command = bot.commands.get(name)
@ -117,7 +117,7 @@ def _default_help_command(ctx, *commands : str):
yield from destination.send(bot.command_has_no_subcommands.format(command, key))
return
pages = bot.formatter.format_help_for(ctx, command)
pages = yield from bot.formatter.format_help_for(ctx, command)
if bot.pm_help is None:
characters = sum(map(lambda l: len(l), pages))
@ -218,9 +218,9 @@ class BotBase(GroupMixin):
on a per command basis except it is run before any command checks
have been verified and applies to every command the bot has.
.. warning::
.. info::
This function must be a *regular* function and not a coroutine.
This function can either be a regular function or a coroutine.
Similar to a command :func:`check`\, this takes a single parameter
of type :class:`Context` and can only raise exceptions derived from
@ -268,8 +268,12 @@ class BotBase(GroupMixin):
except ValueError:
pass
@asyncio.coroutine
def can_run(self, ctx):
return all(f(ctx) for f in self._checks)
if len(self._checks) == 0:
return True
return (yield from discord.utils.async_all(f(ctx) for f in self._checks))
def before_invoke(self, coro):
"""A decorator that registers a coroutine as a pre-invoke hook.

32
discord/ext/commands/core.py

@ -342,6 +342,7 @@ class Command:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
@asyncio.coroutine
def _verify_checks(self, ctx):
if not self.enabled:
raise DisabledCommand('{0.name} command is disabled'.format(self))
@ -349,10 +350,7 @@ class Command:
if self.no_pm and not isinstance(ctx.channel, discord.abc.GuildChannel):
raise NoPrivateMessage('This command cannot be used in private messages.')
if not ctx.bot.can_run(ctx):
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
if not self.can_run(ctx):
if not (yield from self.can_run(ctx)):
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
@asyncio.coroutine
@ -402,7 +400,7 @@ class Command:
@asyncio.coroutine
def prepare(self, ctx):
ctx.command = self
self._verify_checks(ctx)
yield from self._verify_checks(ctx)
yield from self._parse_arguments(ctx)
if self._buckets.valid:
@ -533,14 +531,17 @@ class Command:
return self.help.split('\n', 1)[0]
return ''
def can_run(self, context):
"""Checks if the command can be executed by checking all the predicates
@asyncio.coroutine
def can_run(self, ctx):
"""|coro|
Checks if the command can be executed by checking all the predicates
inside the :attr:`checks` attribute.
Parameters
-----------
context : :class:`Context`
The context of the command currently being invoked.
ctx: :class:`Context`
The ctx of the command currently being invoked.
Returns
--------
@ -548,6 +549,9 @@ class Command:
A boolean indicating if the command can be invoked.
"""
if not (yield from ctx.bot.can_run(ctx)):
raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
cog = self.instance
if cog is not None:
try:
@ -555,14 +559,16 @@ class Command:
except AttributeError:
pass
else:
if not local_check(context):
ret = yield from discord.utils.maybe_coroutine(local_check, ctx)
if not ret:
return False
predicates = self.checks
if not predicates:
# since we have no checks, then we just return True.
return True
return all(predicate(context) for predicate in predicates)
return (yield from discord.utils.async_all(predicate(ctx) for predicate in predicates))
class GroupMixin:
"""A mixin that implements common functionality for classes that behave
@ -855,6 +861,10 @@ def check(predicate):
will be propagated while those subclassed will be sent to
:func:`on_command_error`.
.. info::
These functions can either be regular functions or coroutines.
Parameters
-----------
predicate

44
discord/ext/commands/formatter.py

@ -26,9 +26,11 @@ DEALINGS IN THE SOFTWARE.
import itertools
import inspect
import asyncio
from .core import GroupMixin, Command
from .errors import CommandError
# from discord.iterators import _FilteredAsyncIterator
# help -> shows info of bot on top/bottom and lists subcommands
# help command -> shows detailed info of command
@ -227,6 +229,7 @@ class HelpFormatter:
return "Type {0}{1} command for more info on a command.\n" \
"You can also type {0}{1} category for more info on a category.".format(self.clean_prefix, command_name)
@asyncio.coroutine
def filter_command_list(self):
"""Returns a filtered list of commands based on the two attributes
provided, :attr:`show_check_failure` and :attr:`show_hidden`. Also
@ -238,8 +241,9 @@ class HelpFormatter:
An iterable with the filter being applied. The resulting value is
a (key, value) tuple of the command name and the command itself.
"""
def predicate(tuple):
cmd = tuple[1]
def sane_no_suspension_point_predicate(tup):
cmd = tup[1]
if self.is_cog():
# filter commands that don't exist to this cog.
if cmd.instance is not self.command:
@ -248,18 +252,31 @@ class HelpFormatter:
if cmd.hidden and not self.show_hidden:
return False
if self.show_check_failure:
# we don't wanna bother doing the checks if the user does not
# care about them, so just return true.
return True
return True
@asyncio.coroutine
def predicate(tup):
if sane_no_suspension_point_predicate(tup) is False:
return False
cmd = tup[1]
try:
return cmd.can_run(self.context) and self.context.bot.can_run(self.context)
return (yield from cmd.can_run(self.context))
except CommandError:
return False
iterator = self.command.commands.items() if not self.is_cog() else self.context.bot.commands.items()
return filter(predicate, iterator)
if not self.show_check_failure:
return filter(sane_no_suspension_point_predicate, iterator)
# Gotta run every check and verify it
ret = []
for elem in iterator:
valid = yield from predicate(elem)
if valid:
ret.append(elem)
return ret
def _add_subcommands_to_page(self, max_width, commands):
for name, command in commands:
@ -271,6 +288,7 @@ class HelpFormatter:
shortened = self.shorten(entry)
self._paginator.add_line(shortened)
@asyncio.coroutine
def format_help_for(self, context, command_or_bot):
"""Formats the help page and handles the actual heavy lifting of how
the help command looks like. To change the behaviour, override the
@ -290,8 +308,9 @@ class HelpFormatter:
"""
self.context = context
self.command = command_or_bot
return self.format()
return (yield from self.format())
@asyncio.coroutine
def format(self):
"""Handles the actual behaviour involved with formatting.
@ -334,18 +353,19 @@ class HelpFormatter:
# last place sorting position.
return cog + ':' if cog is not None else '\u200bNo Category:'
filtered = yield from self.filter_command_list()
if self.is_bot():
data = sorted(self.filter_command_list(), key=category)
data = sorted(filtered, key=category)
for category, commands in itertools.groupby(data, key=category):
# there simply is no prettier way of doing this.
commands = list(commands)
commands = sorted(commands)
if len(commands) > 0:
self._paginator.add_line(category)
self._add_subcommands_to_page(max_width, commands)
else:
self._paginator.add_line('Commands:')
self._add_subcommands_to_page(max_width, self.filter_command_list())
self._add_subcommands_to_page(max_width, sorted(filtered))
# add the ending note
self._paginator.add_line()

15
discord/iterators.py

@ -30,18 +30,11 @@ import aiohttp
import datetime
from .errors import NoMoreItems
from .utils import time_snowflake
from .utils import time_snowflake, maybe_coroutine
from .object import Object
PY35 = sys.version_info >= (3, 5)
@asyncio.coroutine
def _probably_coroutine(f, e):
if asyncio.iscoroutinefunction(f):
return (yield from f(e))
else:
return f(e)
class _AsyncIterator:
__slots__ = ()
@ -67,7 +60,7 @@ class _AsyncIterator:
except NoMoreItems:
return None
ret = yield from _probably_coroutine(predicate, elem)
ret = yield from maybe_coroutine(predicate, elem)
if ret:
return elem
@ -114,7 +107,7 @@ class _MappedAsyncIterator(_AsyncIterator):
def get(self):
# this raises NoMoreItems and will propagate appropriately
item = yield from self.iterator.get()
return (yield from _probably_coroutine(self.func, item))
return (yield from maybe_coroutine(self.func, item))
class _FilteredAsyncIterator(_AsyncIterator):
def __init__(self, iterator, predicate):
@ -132,7 +125,7 @@ class _FilteredAsyncIterator(_AsyncIterator):
while True:
# propagate NoMoreItems similar to _MappedAsyncIterator
item = yield from getter()
ret = yield from _probably_coroutine(pred, item)
ret = yield from maybe_coroutine(pred, item)
if ret:
return item

16
discord/utils.py

@ -260,3 +260,19 @@ def _bytes_to_base64_data(data):
def to_json(obj):
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
@asyncio.coroutine
def maybe_coroutine(f, e):
if asyncio.iscoroutinefunction(f):
return (yield from f(e))
else:
return f(e)
@asyncio.coroutine
def async_all(gen):
check = asyncio.iscoroutine
for elem in gen:
if check(elem):
elem = yield from elem
if not elem:
return False
return True

Loading…
Cancel
Save