From 3d3d57a00c6253e482d0d8d374a8f7da231a7809 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 23 Sep 2016 04:29:02 -0500 Subject: [PATCH] Add argument parsing --- disco/bot/bot.py | 6 ++- disco/bot/command.py | 36 ++++++++++--- disco/bot/parser.py | 112 +++++++++++++++++++++++++++++++++++++++ disco/bot/plugin.py | 17 ++++-- examples/basic_plugin.py | 4 +- 5 files changed, 159 insertions(+), 16 deletions(-) create mode 100644 disco/bot/parser.py diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 5c996d0..bce1fe1 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -1,6 +1,7 @@ import re from disco.client import DiscoClient +from disco.bot.command import CommandEvent class BotConfig(object): @@ -97,7 +98,10 @@ class Bot(object): commands = list(self.get_commands_for_message(msg)) if len(commands): - return any((command.execute(msg, match) for command, match in commands)) + return any([ + command.plugin.execute(CommandEvent(command, msg, match)) + for command, match in commands + ]) return False diff --git a/disco/bot/command.py b/disco/bot/command.py index 37fed7c..4648e05 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -1,29 +1,49 @@ import re +from disco.bot.parser import parse_arguments, ArgumentError from disco.util.cache import cached_property +REGEX_FMT = '({})' ARGS_REGEX = '( (.*)$|$)' class CommandEvent(object): - def __init__(self, msg, match): + def __init__(self, command, msg, match): + self.command = command self.msg = msg self.match = match - self.args = self.match.group(1).strip().split(' ') + self.name = self.match.group(1) + self.args = self.match.group(2).strip().split(' ') + + +class CommandError(Exception): + pass class Command(object): - def __init__(self, plugin, func, trigger, aliases=None, group=None, is_regex=False): + def __init__(self, plugin, func, trigger, args=None, aliases=None, group=None, is_regex=False): self.plugin = plugin self.func = func self.triggers = [trigger] + (aliases or []) + self.args = parse_arguments(args or '') self.group = group self.is_regex = is_regex - def execute(self, msg, match): - event = CommandEvent(msg, match) - return self.func(event) + def execute(self, event): + if len(event.args) < self.args.required_length: + raise CommandError('{} requires {} arguments (passed {})'.format( + event.name, + self.args.required_length, + len(event.args) + )) + + try: + args = self.args.parse(event.args) + except ArgumentError as e: + raise CommandError(e.message) + + return self.func(event, *args) @cached_property def compiled_regex(self): @@ -32,7 +52,7 @@ class Command(object): @property def regex(self): if self.is_regex: - return '|'.join(self.triggers) + return REGEX_FMT.format('|'.join(self.triggers)) else: group = self.group + ' ' if self.group else '' - return '|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX + return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX) diff --git a/disco/bot/parser.py b/disco/bot/parser.py new file mode 100644 index 0000000..92671db --- /dev/null +++ b/disco/bot/parser.py @@ -0,0 +1,112 @@ +import re + +PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])') + +TYPE_MAP = { + 'str': str, + 'int': int, + 'float': float, + 'snowflake': int, +} + + +class ArgumentError(Exception): + pass + + +class Argument(object): + def __init__(self, raw): + self.name = None + self.count = 1 + self.required = False + self.types = None + self.parse(raw) + + @property + def true_count(self): + return self.count or 1 + + def convert(self, obj): + for typ in self.types: + typ = TYPE_MAP.get(typ) + try: + return typ(obj) + except Exception as e: + continue + raise e + + def parse(self, raw): + prefix, part = raw + + if prefix == '<': + self.required = True + else: + self.required = False + + if part.endswith('...'): + part = part[:-3] + self.count = 0 + elif ' ' in part: + split = part.split(' ', 1) + part, self.count = split[0], int(split[1]) + + if ':' in part: + part, typeinfo = part.split(':') + self.types = typeinfo.split('|') + + self.name = part.strip() + + +class ArgumentSet(object): + def __init__(self, args=None): + self.args = args or [] + + def append(self, arg): + if self.args and not self.args[-1].required and arg.required: + raise Exception('Required argument cannot come after an optional argument') + + if self.args and not self.args[-1].count: + raise Exception('No arguments can come after a catch-all') + + self.args.append(arg) + + def parse(self, rawargs): + parsed = [] + + for index, arg in enumerate(self.args): + if not arg.required and index + arg.true_count <= len(rawargs): + continue + + raw = rawargs[index:index + arg.true_count] + + if arg.types: + for idx, r in enumerate(raw): + try: + raw[idx] = arg.convert(r) + except: + raise ArgumentError('cannot convert `{}` to `{}`'.format( + r, ', '.join(arg.types) + )) + + parsed.append(raw) + + return parsed + + @property + def length(self): + return len(self.args) + + @property + def required_length(self): + return sum([i.true_count for i in self.args if i.required]) + + +def parse_arguments(line): + args = ArgumentSet() + + data = PARTS_RE.findall(line) + if len(data): + for item in data: + args.append(Argument(item)) + + return args diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 370eab6..6c6811e 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -1,7 +1,7 @@ import inspect import functools -from disco.bot.command import Command +from disco.bot.command import Command, CommandError class PluginDeco(object): @@ -78,20 +78,27 @@ class Plugin(PluginDeco): when, typ = meta['type'].split('_', 1) self.register_trigger(typ, when, member) + def execute(self, event): + try: + return event.command.execute(event) + except CommandError as e: + event.msg.reply(e.message) + return False + def register_trigger(self, typ, when, func): getattr(self, '_' + when)[typ].append(func) - def _dispatch(self, typ, func, event): + def _dispatch(self, typ, func, event, *args, **kwargs): for pre in self._pre[typ]: - event = pre(event) + event = pre(event, args, kwargs) if event is None: return False - result = func(event) + result = func(event, *args, **kwargs) for post in self._post[typ]: - post(event, result) + post(event, args, kwargs, result) return True diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 2d5df1d..df5397c 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -12,8 +12,8 @@ class BasicPlugin(Plugin): def on_test_command(self, event): event.msg.reply('HELLO WORLD') - @Plugin.command('spam') - def on_spam_command(self, event): + @Plugin.command('spam', ' ') + def on_spam_command(self, event, count, content): count = int(event.args[0]) for i in range(count):