From d7478425ca0e52d57ff08f59759ffcc072712e7e Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 10 May 2017 21:30:41 -0400 Subject: [PATCH] [commands] Converter.convert is always a coroutine. Along with this change comes with the removal of Converter.prepare and adding two arguments to Converter.convert, the context and the argument. I suppose an added benefit is that you don't have to do attribute access since it's a local variable. --- discord/ext/commands/converter.py | 149 ++++++++++++++++-------------- discord/ext/commands/core.py | 12 +-- 2 files changed, 83 insertions(+), 78 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index c3d1c787f..36e53fd22 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -53,46 +53,52 @@ class Converter: special cased ``discord`` classes. Classes that derive from this should override the :meth:`convert` method - to do its conversion logic. This method could be a coroutine or a regular - function. - - Before the convert method is called, :meth:`prepare` is called. This - method must set the attributes below if overwritten. - - Attributes - ----------- - ctx: :class:`Context` - The invocation context that the argument is being used in. - argument: str - The argument that is being converted. + to do its conversion logic. This method must be a coroutine. """ - def prepare(self, ctx, argument): - self.ctx = ctx - self.argument = argument - def convert(self): + @asyncio.coroutine + def convert(self, ctx, argument): + """|coro| + + The method to override to do conversion logic. + + This can either be a coroutine or a regular function. + + If an error is found while converting, it is recommended to + raise a :class:`CommandError` derived exception as it will + properly propagate to the error handlers. + + Parameters + ----------- + ctx: :class:`Context` + The invocation context that the argument is being used in. + argument: str + The argument that is being converted. + """ raise NotImplementedError('Derived classes need to implement this.') class IDConverter(Converter): def __init__(self): self._id_regex = re.compile(r'([0-9]{15,21})$') + super().__init__() - def _get_id_match(self): - return self._id_regex.match(self.argument) + def _get_id_match(self, argument): + return self._id_regex.match(argument) class MemberConverter(IDConverter): - def convert(self): - message = self.ctx.message - bot = self.ctx.bot - match = self._get_id_match() or re.match(r'<@!?([0-9]+)>$', self.argument) + @asyncio.coroutine + def convert(self, ctx, argument): + message = ctx.message + bot = ctx.bot + match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) guild = message.guild result = None if match is None: # not a mention... if guild: - result = guild.get_member_named(self.argument) + result = guild.get_member_named(argument) else: - result = _get_from_guilds(bot, 'get_member_named', self.argument) + result = _get_from_guilds(bot, 'get_member_named', argument) else: user_id = int(match.group(1)) if guild: @@ -101,21 +107,22 @@ class MemberConverter(IDConverter): result = _get_from_guilds(bot, 'get_member', user_id) if result is None: - raise BadArgument('Member "{}" not found'.format(self.argument)) + raise BadArgument('Member "{}" not found'.format(argument)) return result class UserConverter(IDConverter): - def convert(self): - match = self._get_id_match() or re.match(r'<@!?([0-9]+)>$', self.argument) + @asyncio.coroutine + def convert(self, ctx, argument): + match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) result = None - state = self.ctx._state + state = ctx._state if match is not None: user_id = int(match.group(1)) - result = self.ctx.bot.get_user(user_id) + result = ctx.bot.get_user(user_id) else: - arg = self.argument + arg = argument # check for discriminator if it exists if len(arg) > 5 and arg[-5] == '#': discrim = arg[-4:] @@ -129,25 +136,26 @@ class UserConverter(IDConverter): result = discord.utils.find(predicate, state._users.values()) if result is None: - raise BadArgument('User "{}" not found'.format(self.argument)) + raise BadArgument('User "{}" not found'.format(argument)) return result class TextChannelConverter(IDConverter): - def convert(self): - bot = self.ctx.bot + @asyncio.coroutine + def convert(self, ctx, argument): + bot = ctx.bot - match = self._get_id_match() or re.match(r'<#([0-9]+)>$', self.argument) + match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) result = None - guild = self.ctx.guild + guild = ctx.guild if match is None: # not a mention if guild: - result = discord.utils.get(guild.text_channels, name=self.argument) + result = discord.utils.get(guild.text_channels, name=argument) else: def check(c): - return isinstance(c, discord.TextChannel) and c.name == self.argument + return isinstance(c, discord.TextChannel) and c.name == argument result = discord.utils.find(check, bot.get_all_channels()) else: channel_id = int(match.group(1)) @@ -157,25 +165,25 @@ class TextChannelConverter(IDConverter): result = _get_from_guilds(bot, 'get_channel', channel_id) if not isinstance(result, discord.TextChannel): - raise BadArgument('Channel "{}" not found.'.format(self.argument)) + raise BadArgument('Channel "{}" not found.'.format(argument)) return result class VoiceChannelConverter(IDConverter): - def convert(self): - bot = self.ctx.bot - - match = self._get_id_match() or re.match(r'<#([0-9]+)>$', self.argument) + @asyncio.coroutine + def convert(self, ctx, argument): + bot = ctx.bot + match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) result = None - guild = self.ctx.guild + guild = ctx.guild if match is None: # not a mention if guild: - result = discord.utils.get(guild.voice_channels, name=self.argument) + result = discord.utils.get(guild.voice_channels, name=argument) else: def check(c): - return isinstance(c, discord.VoiceChannel) and c.name == self.argument + return isinstance(c, discord.VoiceChannel) and c.name == argument result = discord.utils.find(check, bot.get_all_channels()) else: channel_id = int(match.group(1)) @@ -185,13 +193,14 @@ class VoiceChannelConverter(IDConverter): result = _get_from_guilds(bot, 'get_channel', channel_id) if not isinstance(result, discord.VoiceChannel): - raise BadArgument('Channel "{}" not found.'.format(self.argument)) + raise BadArgument('Channel "{}" not found.'.format(argument)) return result class ColourConverter(Converter): - def convert(self): - arg = self.argument.replace('0x', '').lower() + @asyncio.coroutine + def convert(self, ctx, argument): + arg = argument.replace('0x', '').lower() if arg[0] == '#': arg = arg[1:] @@ -205,47 +214,48 @@ class ColourConverter(Converter): return method() class RoleConverter(IDConverter): - def convert(self): - guild = self.ctx.message.guild + @asyncio.coroutine + def convert(self, ctx, argument): + guild = ctx.message.guild if not guild: raise NoPrivateMessage() - match = self._get_id_match() or re.match(r'<@&([0-9]+)>$', self.argument) - params = dict(id=int(match.group(1))) if match else dict(name=self.argument) + match = self._get_id_match(argument) or re.match(r'<@&([0-9]+)>$', argument) + params = dict(id=int(match.group(1))) if match else dict(name=argument) result = discord.utils.get(guild.roles, **params) if result is None: - raise BadArgument('Role "{}" not found.'.format(self.argument)) + raise BadArgument('Role "{}" not found.'.format(argument)) return result class GameConverter(Converter): - def convert(self): - return discord.Game(name=self.argument) + @asyncio.coroutine + def convert(self, ctx, argument): + return discord.Game(name=argument) class InviteConverter(Converter): @asyncio.coroutine - def convert(self): + def convert(self, ctx, argument): try: - invite = yield from self.ctx.bot.get_invite(self.argument) + invite = yield from ctx.bot.get_invite(argument) return invite except Exception as e: raise BadArgument('Invite is invalid or expired') from e class EmojiConverter(IDConverter): @asyncio.coroutine - def convert(self): - message = self.ctx.message - bot = self.ctx.bot - - match = self._get_id_match() or re.match(r'<:[a-zA-Z0-9]+:([0-9]+)>$', self.argument) + def convert(self, ctx, argument): + match = self._get_id_match(argument) or re.match(r'<:[a-zA-Z0-9]+:([0-9]+)>$', argument) result = None - guild = message.guild + bot = ctx.bot + guild = ctx.guild + if match is None: # Try to get the emoji by name. Try local guild first. if guild: - result = discord.utils.get(guild.emojis, name=self.argument) + result = discord.utils.get(guild.emojis, name=argument) if result is None: - result = discord.utils.get(bot.emojis, name=self.argument) + result = discord.utils.get(bot.emojis, name=argument) else: emoji_id = int(match.group(1)) @@ -257,7 +267,7 @@ class EmojiConverter(IDConverter): result = discord.utils.get(bot.emojis, id=emoji_id) if result is None: - raise BadArgument('Emoji "{}" not found.'.format(self.argument)) + raise BadArgument('Emoji "{}" not found.'.format(argument)) return result @@ -266,8 +276,9 @@ class clean_content(Converter): self.fix_channel_mentions = fix_channel_mentions self.use_nicknames = use_nicknames - def convert(self): - message = self.ctx.message + @asyncio.coroutine + def convert(self, ctx, argument): + message = ctx.message transformations = {} if self.fix_channel_mentions: @@ -306,7 +317,7 @@ class clean_content(Converter): return transformations.get(obj.group(0), '') pattern = re.compile('|'.join(transformations.keys())) - result = pattern.sub(repl, self.argument) + result = pattern.sub(repl, argument) transformations = { '@everyone': '@\u200beveryone', diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 482238951..0b05a1aef 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -202,13 +202,10 @@ class Command: if inspect.isclass(converter) and issubclass(converter, converters.Converter): instance = converter() - instance.prepare(ctx, argument) - ret = yield from discord.utils.maybe_coroutine(instance.convert) + ret = yield from instance.convert(ctx, argument) return ret - - if isinstance(converter, converters.Converter): - converter.prepare(ctx, argument) - ret = yield from discord.utils.maybe_coroutine(converter.convert) + elif isinstance(converter, converters.Converter): + ret = yield from converter.convert(ctx, argument) return ret return converter(argument) @@ -220,9 +217,6 @@ class Command: converter = str if param.default is None else type(param.default) else: converter = str - elif not inspect.isclass(type(converter)): - raise discord.ClientException('Function annotation must be a type') - return converter @asyncio.coroutine