diff --git a/discord/ext/commands/__init__.py b/discord/ext/commands/__init__.py index 0bf1590e1..6d82962de 100644 --- a/discord/ext/commands/__init__.py +++ b/discord/ext/commands/__init__.py @@ -15,3 +15,4 @@ from .context import Context from .core import * from .errors import * from .formatter import HelpFormatter +from .converter import * diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py new file mode 100644 index 000000000..4f34956b4 --- /dev/null +++ b/discord/ext/commands/converter.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2016 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import discord +import asyncio +import re + +from .errors import BadArgument, NoPrivateMessage + +__all__ = [ 'Converter', 'MemberConverter', 'UserConverter', + 'ChannelConverter', 'InviteConverter', 'RoleConverter', + 'GameConverter', 'ColourConverter' ] + +class Converter: + """The base class of custom converters that require the :class:`Context` + to be passed to be useful. + + This allows you to implement converters that function similar to the + special cased ``discord`` classes. + + Classes that derive from this should override the :meth:`convert` method + to do its conversion logic. This method could be a coroutine or a regular + function. + + Attributes + ----------- + ctx: :class:`Context` + The invocation context that the argument is being used in. + argument: str + The argument that is being converted. + """ + def __init__(self, ctx, argument): + self.ctx = ctx + self.argument = argument + + def convert(self): + raise NotImplementedError('Derived classes need to implement this.') + +class MemberConverter(Converter): + def convert(self): + message = self.ctx.message + bot = self.ctx.bot + + match = re.match(r'<@!?([0-9]+)>$', self.argument) + server = message.server + result = None + if match is None: + # not a mention... + if server: + result = server.get_member_named(self.argument) + else: + result = self._get_from_servers(bot, 'get_member_named', self.argument) + else: + user_id = match.group(1) + if server: + result = server.get_member(user_id) + else: + result = self._get_from_servers(bot, 'get_member', user_id) + + if result is None: + raise BadArgument('Member "{}" not found'.format(self.argument)) + + return result + +UserConverter = MemberConverter + +class ChannelConverter(Converter): + def convert(self): + message = self.ctx.message + bot = self.ctx.bot + + match = re.match(r'<#([0-9]+)>$', self.argument) + result = None + server = message.server + if match is None: + # not a mention + if server: + result = discord.utils.get(server.channels, name=self.argument) + else: + result = discord.utils.get(bot.get_all_channels(), name=self.argument) + else: + channel_id = match.group(1) + if server: + result = server.get_channel(channel_id) + else: + result = self._get_from_servers(bot, 'get_channel', channel_id) + + if result is None: + raise BadArgument('Channel "{}" not found.'.format(self.argument)) + + return result + +class ColourConverter(Converter): + def convert(self): + arg = self.argument.replace('0x', '').lower() + + if arg[0] == '#': + arg = arg[1:] + try: + value = int(arg, base=16) + return discord.Colour(value=value) + except ValueError: + method = getattr(discord.Colour, arg, None) + if method is None or not inspect.ismethod(method): + raise BadArgument('Colour "{}" is invalid.'.format(arg)) + return method() + +class RoleConverter(Converter): + def convert(self): + server = self.ctx.message.server + if not server: + raise NoPrivateMessage() + + match = re.match(r'<@&([0-9]+)>$', self.argument) + params = dict(id=match.group(1)) if match else dict(name=self.argument) + result = discord.utils.get(server.roles, **params) + if result is None: + raise BadArgument('Role "{}" not found.'.format(self.argument)) + return result + +class GameConverter(Converter): + def convert(self): + return discord.Game(name=self.argument) + +class InviteConverter(Converter): + @asyncio.coroutine + def convert(self): + try: + invite = yield from self.ctx.bot.get_invite(self.argument) + return invite + except Exception as e: + raise BadArgument('Invite is invalid or expired') from e diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index e28116508..18a2d9d43 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -26,17 +26,16 @@ DEALINGS IN THE SOFTWARE. import asyncio import inspect -import re import discord import functools from .errors import * from .view import quoted_word +from . import converter as converters __all__ = [ 'Command', 'Group', 'GroupMixin', 'command', 'group', 'has_role', 'has_permissions', 'has_any_role', 'check', - 'bot_has_role', 'bot_has_permissions', 'bot_has_any_role', - 'Converter' ] + 'bot_has_role', 'bot_has_permissions', 'bot_has_any_role' ] def inject_context(ctx, coro): @functools.wraps(coro) @@ -61,31 +60,6 @@ def _convert_to_bool(argument): else: raise BadArgument(lowered + ' is not a recognised boolean option') -class Converter: - """The base class of custom converters that require the :class:`Context` - to be passed to be useful. - - This allows you to implement converters that function similar to the - special cased ``discord`` classes. - - Classes that derive from this should override the :meth:`convert` method - to do its conversion logic. This method could be a coroutine or a regular - function. - - Attributes - ----------- - ctx: :class:`Context` - The invocation context that the argument is being used in. - argument: str - The argument that is being converted. - """ - def __init__(self, ctx, argument): - self.ctx = ctx - self.argument = argument - - def convert(self): - raise NotImplementedError('Derived classes need to implement this.') - class Command: """A class that implements the protocol for a bot text command. @@ -192,105 +166,22 @@ class Command: return result return result - def _convert_member(self, bot, message, argument): - match = re.match(r'<@!?([0-9]+)>$', argument) - server = message.server - result = None - if match is None: - # not a mention... - if server: - result = server.get_member_named(argument) - else: - result = self._get_from_servers(bot, 'get_member_named', argument) - else: - user_id = match.group(1) - if server: - result = server.get_member(user_id) - else: - result = self._get_from_servers(bot, 'get_member', user_id) - - if result is None: - raise BadArgument('Member "{}" not found'.format(argument)) - - return result - - _convert_user = _convert_member - - def _convert_channel(self, bot, message, argument): - match = re.match(r'<#([0-9]+)>$', argument) - result = None - server = message.server - if match is None: - # not a mention - if server: - result = discord.utils.get(server.channels, name=argument) - else: - result = discord.utils.get(bot.get_all_channels(), name=argument) - else: - channel_id = match.group(1) - if server: - result = server.get_channel(channel_id) - else: - result = self._get_from_servers(bot, 'get_channel', channel_id) - - if result is None: - raise BadArgument('Channel "{}" not found.'.format(argument)) - - return result - - def _convert_colour(self, bot, message, argument): - arg = argument.replace('0x', '').lower() - if arg[0] == '#': - arg = arg[1:] - try: - value = int(arg, base=16) - return discord.Colour(value=value) - except ValueError: - method = getattr(discord.Colour, arg, None) - if method is None or not inspect.ismethod(method): - raise BadArgument('Colour "{}" is invalid.'.format(arg)) - return method() - - def _convert_role(self, bot, message, argument): - server = message.server - if not server: - raise NoPrivateMessage() - - match = re.match(r'<@&([0-9]+)>$', argument) - params = dict(id=match.group(1)) if match else dict(name=argument) - result = discord.utils.get(server.roles, **params) - if result is None: - raise BadArgument('Role "{}" not found.'.format(argument)) - return result - - def _convert_game(self, bot, message, argument): - return discord.Game(name=argument) - @asyncio.coroutine def do_conversion(self, ctx, converter, argument): if converter is bool: return _convert_to_bool(argument) - if issubclass(converter, Converter): + if converter.__module__.startswith('discord.'): + converter = getattr(converters, converter.__name__ + 'Converter') + + if issubclass(converter, converters.Converter): instance = converter(ctx, argument) if asyncio.iscoroutinefunction(instance.convert): return (yield from instance.convert()) else: return instance.convert() - if converter.__module__.split('.')[0] != 'discord': - return converter(argument) - - # special handling for discord.py related classes - if converter is discord.Invite: - try: - invite = yield from ctx.bot.get_invite(argument) - return invite - except Exception as e: - raise BadArgument('Invite is invalid or expired') from e - - new_converter = getattr(self, '_convert_{}'.format(converter.__name__.lower())) - return new_converter(ctx.bot, ctx.message, argument) + return converter(argument) def _get_converter(self, param): converter = param.annotation