From dd75502b89a5ef629182a2493101d9fac5e9a810 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 12 Nov 2016 07:22:19 -0600 Subject: [PATCH] Improvements to command processing - Bot.get_commands_for_message is now more composable/externally useable - Command parser for 'user' type has been improved to allow username/discrim combos - Model loading can now be done outside of the model constructor, and supports some utility arguments - Fix sub-model fields not having their default value be the sub-model constructor - Fix Message.without_mentions - Add Message.with_proper_mentions (e.g. humanifying the message) - Cleanup Message.replace_mentions for the above two changes - Fix some weird casting inside MessageTable --- disco/bot/bot.py | 45 +++++++++++++++++--------------- disco/bot/command.py | 31 ++++++++++++---------- disco/types/base.py | 27 ++++++++++++++++---- disco/types/message.py | 58 +++++++++++++++++++++++++++++++----------- 4 files changed, 108 insertions(+), 53 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index ec670e1..2f97832 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -221,7 +221,7 @@ class Bot(object): else: self.command_matches_re = None - def get_commands_for_message(self, msg): + def get_commands_for_message(self, require_mention, mention_rules, prefix, msg): """ Generator of all commands that a given message object triggers, based on the bots plugins and configuration. @@ -238,7 +238,7 @@ class Bot(object): """ content = msg.content - if self.config.commands_require_mention: + if require_mention: mention_direct = msg.is_mentioned(self.client.state.me) mention_everyone = msg.mention_everyone @@ -248,9 +248,9 @@ class Bot(object): msg.guild.get_member(self.client.state.me).roles)) if not any(( - self.config.commands_mention_rules['user'] and mention_direct, - self.config.commands_mention_rules['everyone'] and mention_everyone, - self.config.commands_mention_rules['role'] and any(mention_roles), + mention_rules.get('user', True) and mention_direct, + mention_rules.get('everyone', False) and mention_everyone, + mention_rules.get('role', False) and any(mention_roles), msg.channel.is_dm )): raise StopIteration @@ -270,10 +270,10 @@ class Bot(object): content = content.lstrip() - if self.config.commands_prefix and not content.startswith(self.config.commands_prefix): + if prefix and not content.startswith(prefix): raise StopIteration else: - content = content[len(self.config.commands_prefix):] + content = content[len(prefix):] if not self.command_matches_re or not self.command_matches_re.match(content): raise StopIteration @@ -324,19 +324,24 @@ class Bot(object): bool whether any commands where successfully triggered by the message """ - commands = list(self.get_commands_for_message(msg)) - - if len(commands): - result = False - for command, match in commands: - if not self.check_command_permissions(command, msg): - continue - - if command.plugin.execute(CommandEvent(command, msg, match)): - result = True - return result - - return False + commands = list(self.get_commands_for_message( + self.config.commands_require_mention, + self.config.commands_mention_rules, + self.config.commands_prefix, + msg + )) + + if not len(commands): + return False + + result = False + for command, match in commands: + if not self.check_command_permissions(command, msg): + continue + + if command.plugin.execute(CommandEvent(command, msg, match)): + result = True + return result def on_message_create(self, event): if event.message.author.id == self.client.state.me.id: diff --git a/disco/bot/command.py b/disco/bot/command.py index 40ccbea..08c9b9f 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -140,11 +140,14 @@ class Command(object): return ctx.msg.guild.roles.get(rid) def resolve_user(ctx, uid): - return ctx.msg.mentions.get(uid) + if isinstance(uid, int): + return ctx.msg.mentions.get(uid) + else: + return ctx.msg.mentions.select_one(username=uid[0], discriminator=uid[1]) self.args = ArgumentSet.from_string(args or '', { 'mention': self.mention_type([resolve_role, resolve_user]), - 'user': self.mention_type([resolve_user], force=True), + 'user': self.mention_type([resolve_user], force=True, user=True), 'role': self.mention_type([resolve_role], force=True), }) @@ -156,27 +159,29 @@ class Command(object): self.dispatch_func = dispatch_func @staticmethod - def mention_type(getters, force=False): - def _f(ctx, i): - # TODO: support full discrim format? make this betteR? - if i.isdigit(): - mid = int(i) + def mention_type(getters, force=False, user=False): + def _f(ctx, raw): + if raw.isdigit(): + resolved = int(raw) + elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit(): + username, discrim = raw.split('#') + resolved = (username, int(discrim)) else: - res = MENTION_RE.match(i) + res = MENTION_RE.match(raw) if not res: - raise TypeError('Invalid mention: {}'.format(i)) + raise TypeError('Invalid mention: {}'.format(raw)) - mid = int(res.group(1)) + resolved = int(res.group(1)) for getter in getters: - obj = getter(ctx, mid) + obj = getter(ctx, resolved) if obj: return obj if force: - raise TypeError('Cannot resolve mention: {}'.format(id)) + raise TypeError('Cannot resolve mention: {}'.format(raw)) - return mid + return resolved return _f @cached_property diff --git a/disco/types/base.py b/disco/types/base.py index 1013906..56657ff 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -34,9 +34,10 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None): + def __init__(self, value_type, alias=None, default=None, **kwargs): self.src_name = alias self.dst_name = None + self.metadata = kwargs if default is not None: self.default = default @@ -50,6 +51,8 @@ class Field(object): if isinstance(self.deserializer, Field) and self.default is None: self.default = self.deserializer.default + elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None: + self.default = self.deserializer @property def name(self): @@ -275,8 +278,22 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): else: obj = kwargs - for name, field in six.iteritems(self.__class__._fields): - if field.src_name not in obj or obj[field.src_name] is None: + self.load(obj) + + @property + def fields(self): + return self.__class__._fields + + def load(self, obj, consume=False, skip=None): + for name, field in six.iteritems(self.fields): + should_skip = skip and name in skip + + if consume and not should_skip: + raw = obj.pop(field.src_name, None) + else: + raw = obj.get(field.src_name, None) + + if raw is None or should_skip: if field.has_default(): default = field.default() if callable(field.default) else field.default else: @@ -284,11 +301,11 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): setattr(self, field.dst_name, default) continue - value = field.try_convert(obj[field.src_name], self.client) + value = field.try_convert(raw, self.client) setattr(self, field.dst_name, value) def update(self, other): - for name in six.iterkeys(self.__class__._fields): + for name in six.iterkeys(self.fields): if hasattr(other, name) and not getattr(other, name) is UNSET: setattr(self, name, getattr(other, name)) diff --git a/disco/types/message.py b/disco/types/message.py index b5eaf17..e7b2873 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,4 +1,5 @@ import re +import functools from holster.enum import Enum @@ -310,18 +311,33 @@ class Message(SlottedModel): return entity in self.mentions or entity in self.mention_roles @cached_property - def without_mentions(self): + def without_mentions(self, valid_only=False): """ Returns ------- str - the message contents with all valid mentions removed. + the message contents with all mentions removed. """ return self.replace_mentions( lambda u: '', - lambda r: '') + lambda r: '', + lambda c: '', + nonexistant=not valid_only) - def replace_mentions(self, user_replace, role_replace): + @cached_property + def with_proper_mentions(self): + def replace_user(u): + return '@' + str(u) + + def replace_role(r): + return '@' + str(r) + + def replace_channel(c): + return str(c) + + return self.replace_mentions(replace_user, replace_role, replace_channel) + + def replace_mentions(self, user_replace=None, role_replace=None, channel_replace=None, nonexistant=False): """ Replaces user and role mentions with the result of a given lambda/function. @@ -339,17 +355,30 @@ class Message(SlottedModel): str The message contents with all valid mentions replaced. """ - if not self.mentions and not self.mention_roles: - return + def replace(getter, func, match): + oid = int(match.group(2)) + obj = getter(oid) + + if obj or nonexistant: + return func(obj or oid) or match.group(0) + + return match.group(0) + + content = self.content + + if user_replace: + replace_user = functools.partial(replace, self.mentions.get, user_replace) + content = re.sub('(<@!?([0-9]+)>)', replace_user, self.content) + + if role_replace: + replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace) + content = re.sub('(<@&([0-9]+)>)', replace_role, content) - def replace(match): - oid = match.group(0) - if oid in self.mention_roles: - return role_replace(oid) - else: - return user_replace(self.mentions.get(oid)) + if channel_replace: + replace_channel = functools.partial(replace, self.client.state.channels.get, channel_replace) + content = re.sub('(<#([0-9]+)>)', replace_channel, content) - return re.sub('<@!?([0-9]+)>', replace, self.content) + return content class MessageTable(object): @@ -371,8 +400,7 @@ class MessageTable(object): self.recalculate_size_index(args) def add(self, *args): - convert = lambda v: v if isinstance(v, basestring) else str(v) - args = list(map(convert, args)) + args = list(map(lambda v: v if isinstance(v, basestring) else str(v), args)) self.entries.append(args) self.recalculate_size_index(args)