diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index fbdaa75ab..7b6cd6754 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -33,6 +33,16 @@ from .view import StringView from .context import Context from .errors import CommandNotFound +def _get_variable(name): + stack = inspect.stack() + try: + for frames in stack: + current_locals = frames[0].f_locals + if name in current_locals: + return current_locals[name] + finally: + del stack + def when_mentioned(bot, msg): """A callable that implements a command prefix equivalent to being mentioned, e.g. ``@bot ``.""" @@ -71,13 +81,6 @@ class Bot(GroupMixin, discord.Client): # internal helpers - def _get_variable(self, name): - stack = inspect.stack() - for frames in stack: - current_locals = frames[0].f_locals - if name in current_locals: - return current_locals[name] - def _get_prefix(self, message): prefix = self.command_prefix if callable(prefix): @@ -122,7 +125,7 @@ class Bot(GroupMixin, discord.Client): content : str The content to pass to :class:`Client.send_message` """ - destination = self._get_variable('_internal_channel') + destination = _get_variable('_internal_channel') result = yield from self.send_message(destination, content) return result @@ -141,7 +144,7 @@ class Bot(GroupMixin, discord.Client): content : str The content to pass to :class:`Client.send_message` """ - destination = self._get_variable('_internal_author') + destination = _get_variable('_internal_author') result = yield from self.send_message(destination, content) return result @@ -161,8 +164,8 @@ class Bot(GroupMixin, discord.Client): content : str The content to pass to :class:`Client.send_message` """ - author = self._get_variable('_internal_author') - destination = self._get_variable('_internal_channel') + author = _get_variable('_internal_author') + destination = _get_variable('_internal_channel') fmt = '{0.mention}, {1}'.format(author, str(content)) result = yield from self.send_message(destination, fmt) return result @@ -184,7 +187,7 @@ class Bot(GroupMixin, discord.Client): name The second parameter to pass to :meth:`Client.send_file` """ - destination = self._get_variable('_internal_channel') + destination = _get_variable('_internal_channel') result = yield from self.send_file(destination, fp, name) return result @@ -202,7 +205,7 @@ class Bot(GroupMixin, discord.Client): --------- The :meth:`Client.send_typing` function. """ - destination = self._get_variable('_internal_channel') + destination = _get_variable('_internal_channel') yield from self.send_typing(destination) # listener registration diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 737df74d9..14135b084 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -28,7 +28,7 @@ import asyncio import inspect import re import discord -from functools import partial +import functools from .errors import * from .view import quoted_word @@ -36,6 +36,17 @@ from .view import quoted_word __all__ = [ 'Command', 'Group', 'GroupMixin', 'command', 'group', 'has_role', 'has_permissions', 'has_any_role', 'check' ] +def inject_context(ctx, coro): + @functools.wraps(coro) + @asyncio.coroutine + def wrapped(*args, **kwargs): + _internal_channel = ctx.message.channel + _internal_author = ctx.message.author + + ret = yield from coro(*args, **kwargs) + return ret + return wrapped + def _convert_to_bool(argument): lowered = argument.lower() if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): @@ -103,10 +114,11 @@ class Command: except AttributeError: return + injected = inject_context(ctx, coro) if self.instance is not None: - discord.utils.create_task(coro(self.instance, error, ctx), loop=ctx.bot.loop) + discord.utils.create_task(injected(self.instance, error, ctx), loop=ctx.bot.loop) else: - discord.utils.create_task(coro(error, ctx), loop=ctx.bot.loop) + discord.utils.create_task(injected(error, ctx), loop=ctx.bot.loop) def _receive_item(self, message, argument, regex, receiver, generator): match = re.match(regex, argument) @@ -263,7 +275,8 @@ class Command: return if self._parse_arguments(ctx): - yield from self.callback(*ctx.args, **ctx.kwargs) + injected = inject_context(ctx, self.callback) + yield from injected(*ctx.args, **ctx.kwargs) def error(self, coro): """A decorator that registers a coroutine as a local error handler. @@ -425,7 +438,8 @@ class Group(GroupMixin, Command): if trigger in self.commands: ctx.invoked_subcommand = self.commands[trigger] - yield from self.callback(*ctx.args, **ctx.kwargs) + injected = inject_context(ctx, self.callback) + yield from injected(*ctx.args, **ctx.kwargs) if ctx.invoked_subcommand: ctx.invoked_with = trigger @@ -616,7 +630,7 @@ def has_any_role(*names): if ch.is_private: return False - getter = partial(discord.utils.get, msg.author.roles) + getter = functools.partial(discord.utils.get, msg.author.roles) return any(getter(name=name) is not None for name in names) return check(predicate)