Browse Source

[commands] Converter.convert is always a coroutine.

Along with this change comes with the removal of Converter.prepare and
adding two arguments to Converter.convert, the context and the argument.

I suppose an added benefit is that you don't have to do attribute
access since it's a local variable.
pull/572/head
Rapptz 8 years ago
parent
commit
d7478425ca
  1. 149
      discord/ext/commands/converter.py
  2. 12
      discord/ext/commands/core.py

149
discord/ext/commands/converter.py

@ -53,46 +53,52 @@ class Converter:
special cased ``discord`` classes. special cased ``discord`` classes.
Classes that derive from this should override the :meth:`convert` method Classes that derive from this should override the :meth:`convert` method
to do its conversion logic. This method could be a coroutine or a regular to do its conversion logic. This method must be a coroutine.
function.
Before the convert method is called, :meth:`prepare` is called. This
method must set the attributes below if overwritten.
Attributes
-----------
ctx: :class:`Context`
The invocation context that the argument is being used in.
argument: str
The argument that is being converted.
""" """
def prepare(self, ctx, argument):
self.ctx = ctx
self.argument = argument
def convert(self): @asyncio.coroutine
def convert(self, ctx, argument):
"""|coro|
The method to override to do conversion logic.
This can either be a coroutine or a regular function.
If an error is found while converting, it is recommended to
raise a :class:`CommandError` derived exception as it will
properly propagate to the error handlers.
Parameters
-----------
ctx: :class:`Context`
The invocation context that the argument is being used in.
argument: str
The argument that is being converted.
"""
raise NotImplementedError('Derived classes need to implement this.') raise NotImplementedError('Derived classes need to implement this.')
class IDConverter(Converter): class IDConverter(Converter):
def __init__(self): def __init__(self):
self._id_regex = re.compile(r'([0-9]{15,21})$') self._id_regex = re.compile(r'([0-9]{15,21})$')
super().__init__()
def _get_id_match(self): def _get_id_match(self, argument):
return self._id_regex.match(self.argument) return self._id_regex.match(argument)
class MemberConverter(IDConverter): class MemberConverter(IDConverter):
def convert(self): @asyncio.coroutine
message = self.ctx.message def convert(self, ctx, argument):
bot = self.ctx.bot message = ctx.message
match = self._get_id_match() or re.match(r'<@!?([0-9]+)>$', self.argument) bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
guild = message.guild guild = message.guild
result = None result = None
if match is None: if match is None:
# not a mention... # not a mention...
if guild: if guild:
result = guild.get_member_named(self.argument) result = guild.get_member_named(argument)
else: else:
result = _get_from_guilds(bot, 'get_member_named', self.argument) result = _get_from_guilds(bot, 'get_member_named', argument)
else: else:
user_id = int(match.group(1)) user_id = int(match.group(1))
if guild: if guild:
@ -101,21 +107,22 @@ class MemberConverter(IDConverter):
result = _get_from_guilds(bot, 'get_member', user_id) result = _get_from_guilds(bot, 'get_member', user_id)
if result is None: if result is None:
raise BadArgument('Member "{}" not found'.format(self.argument)) raise BadArgument('Member "{}" not found'.format(argument))
return result return result
class UserConverter(IDConverter): class UserConverter(IDConverter):
def convert(self): @asyncio.coroutine
match = self._get_id_match() or re.match(r'<@!?([0-9]+)>$', self.argument) def convert(self, ctx, argument):
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
result = None result = None
state = self.ctx._state state = ctx._state
if match is not None: if match is not None:
user_id = int(match.group(1)) user_id = int(match.group(1))
result = self.ctx.bot.get_user(user_id) result = ctx.bot.get_user(user_id)
else: else:
arg = self.argument arg = argument
# check for discriminator if it exists # check for discriminator if it exists
if len(arg) > 5 and arg[-5] == '#': if len(arg) > 5 and arg[-5] == '#':
discrim = arg[-4:] discrim = arg[-4:]
@ -129,25 +136,26 @@ class UserConverter(IDConverter):
result = discord.utils.find(predicate, state._users.values()) result = discord.utils.find(predicate, state._users.values())
if result is None: if result is None:
raise BadArgument('User "{}" not found'.format(self.argument)) raise BadArgument('User "{}" not found'.format(argument))
return result return result
class TextChannelConverter(IDConverter): class TextChannelConverter(IDConverter):
def convert(self): @asyncio.coroutine
bot = self.ctx.bot def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match() or re.match(r'<#([0-9]+)>$', self.argument) match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None result = None
guild = self.ctx.guild guild = ctx.guild
if match is None: if match is None:
# not a mention # not a mention
if guild: if guild:
result = discord.utils.get(guild.text_channels, name=self.argument) result = discord.utils.get(guild.text_channels, name=argument)
else: else:
def check(c): def check(c):
return isinstance(c, discord.TextChannel) and c.name == self.argument return isinstance(c, discord.TextChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels())
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
@ -157,25 +165,25 @@ class TextChannelConverter(IDConverter):
result = _get_from_guilds(bot, 'get_channel', channel_id) result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.TextChannel): if not isinstance(result, discord.TextChannel):
raise BadArgument('Channel "{}" not found.'.format(self.argument)) raise BadArgument('Channel "{}" not found.'.format(argument))
return result return result
class VoiceChannelConverter(IDConverter): class VoiceChannelConverter(IDConverter):
def convert(self): @asyncio.coroutine
bot = self.ctx.bot def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match() or re.match(r'<#([0-9]+)>$', self.argument) match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None result = None
guild = self.ctx.guild guild = ctx.guild
if match is None: if match is None:
# not a mention # not a mention
if guild: if guild:
result = discord.utils.get(guild.voice_channels, name=self.argument) result = discord.utils.get(guild.voice_channels, name=argument)
else: else:
def check(c): def check(c):
return isinstance(c, discord.VoiceChannel) and c.name == self.argument return isinstance(c, discord.VoiceChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels())
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
@ -185,13 +193,14 @@ class VoiceChannelConverter(IDConverter):
result = _get_from_guilds(bot, 'get_channel', channel_id) result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.VoiceChannel): if not isinstance(result, discord.VoiceChannel):
raise BadArgument('Channel "{}" not found.'.format(self.argument)) raise BadArgument('Channel "{}" not found.'.format(argument))
return result return result
class ColourConverter(Converter): class ColourConverter(Converter):
def convert(self): @asyncio.coroutine
arg = self.argument.replace('0x', '').lower() def convert(self, ctx, argument):
arg = argument.replace('0x', '').lower()
if arg[0] == '#': if arg[0] == '#':
arg = arg[1:] arg = arg[1:]
@ -205,47 +214,48 @@ class ColourConverter(Converter):
return method() return method()
class RoleConverter(IDConverter): class RoleConverter(IDConverter):
def convert(self): @asyncio.coroutine
guild = self.ctx.message.guild def convert(self, ctx, argument):
guild = ctx.message.guild
if not guild: if not guild:
raise NoPrivateMessage() raise NoPrivateMessage()
match = self._get_id_match() or re.match(r'<@&([0-9]+)>$', self.argument) match = self._get_id_match(argument) or re.match(r'<@&([0-9]+)>$', argument)
params = dict(id=int(match.group(1))) if match else dict(name=self.argument) params = dict(id=int(match.group(1))) if match else dict(name=argument)
result = discord.utils.get(guild.roles, **params) result = discord.utils.get(guild.roles, **params)
if result is None: if result is None:
raise BadArgument('Role "{}" not found.'.format(self.argument)) raise BadArgument('Role "{}" not found.'.format(argument))
return result return result
class GameConverter(Converter): class GameConverter(Converter):
def convert(self): @asyncio.coroutine
return discord.Game(name=self.argument) def convert(self, ctx, argument):
return discord.Game(name=argument)
class InviteConverter(Converter): class InviteConverter(Converter):
@asyncio.coroutine @asyncio.coroutine
def convert(self): def convert(self, ctx, argument):
try: try:
invite = yield from self.ctx.bot.get_invite(self.argument) invite = yield from ctx.bot.get_invite(argument)
return invite return invite
except Exception as e: except Exception as e:
raise BadArgument('Invite is invalid or expired') from e raise BadArgument('Invite is invalid or expired') from e
class EmojiConverter(IDConverter): class EmojiConverter(IDConverter):
@asyncio.coroutine @asyncio.coroutine
def convert(self): def convert(self, ctx, argument):
message = self.ctx.message match = self._get_id_match(argument) or re.match(r'<:[a-zA-Z0-9]+:([0-9]+)>$', argument)
bot = self.ctx.bot
match = self._get_id_match() or re.match(r'<:[a-zA-Z0-9]+:([0-9]+)>$', self.argument)
result = None result = None
guild = message.guild bot = ctx.bot
guild = ctx.guild
if match is None: if match is None:
# Try to get the emoji by name. Try local guild first. # Try to get the emoji by name. Try local guild first.
if guild: if guild:
result = discord.utils.get(guild.emojis, name=self.argument) result = discord.utils.get(guild.emojis, name=argument)
if result is None: if result is None:
result = discord.utils.get(bot.emojis, name=self.argument) result = discord.utils.get(bot.emojis, name=argument)
else: else:
emoji_id = int(match.group(1)) emoji_id = int(match.group(1))
@ -257,7 +267,7 @@ class EmojiConverter(IDConverter):
result = discord.utils.get(bot.emojis, id=emoji_id) result = discord.utils.get(bot.emojis, id=emoji_id)
if result is None: if result is None:
raise BadArgument('Emoji "{}" not found.'.format(self.argument)) raise BadArgument('Emoji "{}" not found.'.format(argument))
return result return result
@ -266,8 +276,9 @@ class clean_content(Converter):
self.fix_channel_mentions = fix_channel_mentions self.fix_channel_mentions = fix_channel_mentions
self.use_nicknames = use_nicknames self.use_nicknames = use_nicknames
def convert(self): @asyncio.coroutine
message = self.ctx.message def convert(self, ctx, argument):
message = ctx.message
transformations = {} transformations = {}
if self.fix_channel_mentions: if self.fix_channel_mentions:
@ -306,7 +317,7 @@ class clean_content(Converter):
return transformations.get(obj.group(0), '') return transformations.get(obj.group(0), '')
pattern = re.compile('|'.join(transformations.keys())) pattern = re.compile('|'.join(transformations.keys()))
result = pattern.sub(repl, self.argument) result = pattern.sub(repl, argument)
transformations = { transformations = {
'@everyone': '@\u200beveryone', '@everyone': '@\u200beveryone',

12
discord/ext/commands/core.py

@ -202,13 +202,10 @@ class Command:
if inspect.isclass(converter) and issubclass(converter, converters.Converter): if inspect.isclass(converter) and issubclass(converter, converters.Converter):
instance = converter() instance = converter()
instance.prepare(ctx, argument) ret = yield from instance.convert(ctx, argument)
ret = yield from discord.utils.maybe_coroutine(instance.convert)
return ret return ret
elif isinstance(converter, converters.Converter):
if isinstance(converter, converters.Converter): ret = yield from converter.convert(ctx, argument)
converter.prepare(ctx, argument)
ret = yield from discord.utils.maybe_coroutine(converter.convert)
return ret return ret
return converter(argument) return converter(argument)
@ -220,9 +217,6 @@ class Command:
converter = str if param.default is None else type(param.default) converter = str if param.default is None else type(param.default)
else: else:
converter = str converter = str
elif not inspect.isclass(type(converter)):
raise discord.ClientException('Function annotation must be a type')
return converter return converter
@asyncio.coroutine @asyncio.coroutine

Loading…
Cancel
Save