Browse Source

Improvements to command processing

- Bot.get_commands_for_message is now more composable/externally useable
- Command parser for 'user' type has been improved to allow
username/discrim combos
- Model loading can now be done outside of the model constructor, and
supports some utility arguments
- Fix sub-model fields not having their default value be the sub-model
constructor
- Fix Message.without_mentions
- Add Message.with_proper_mentions (e.g. humanifying the message)
- Cleanup Message.replace_mentions for the above two changes
- Fix some weird casting inside MessageTable
pull/11/head
Andrei 9 years ago
parent
commit
dd75502b89
  1. 45
      disco/bot/bot.py
  2. 31
      disco/bot/command.py
  3. 27
      disco/types/base.py
  4. 58
      disco/types/message.py

45
disco/bot/bot.py

@ -221,7 +221,7 @@ class Bot(object):
else: else:
self.command_matches_re = None self.command_matches_re = None
def get_commands_for_message(self, msg): def get_commands_for_message(self, require_mention, mention_rules, prefix, msg):
""" """
Generator of all commands that a given message object triggers, based on Generator of all commands that a given message object triggers, based on
the bots plugins and configuration. the bots plugins and configuration.
@ -238,7 +238,7 @@ class Bot(object):
""" """
content = msg.content content = msg.content
if self.config.commands_require_mention: if require_mention:
mention_direct = msg.is_mentioned(self.client.state.me) mention_direct = msg.is_mentioned(self.client.state.me)
mention_everyone = msg.mention_everyone mention_everyone = msg.mention_everyone
@ -248,9 +248,9 @@ class Bot(object):
msg.guild.get_member(self.client.state.me).roles)) msg.guild.get_member(self.client.state.me).roles))
if not any(( if not any((
self.config.commands_mention_rules['user'] and mention_direct, mention_rules.get('user', True) and mention_direct,
self.config.commands_mention_rules['everyone'] and mention_everyone, mention_rules.get('everyone', False) and mention_everyone,
self.config.commands_mention_rules['role'] and any(mention_roles), mention_rules.get('role', False) and any(mention_roles),
msg.channel.is_dm msg.channel.is_dm
)): )):
raise StopIteration raise StopIteration
@ -270,10 +270,10 @@ class Bot(object):
content = content.lstrip() content = content.lstrip()
if self.config.commands_prefix and not content.startswith(self.config.commands_prefix): if prefix and not content.startswith(prefix):
raise StopIteration raise StopIteration
else: else:
content = content[len(self.config.commands_prefix):] content = content[len(prefix):]
if not self.command_matches_re or not self.command_matches_re.match(content): if not self.command_matches_re or not self.command_matches_re.match(content):
raise StopIteration raise StopIteration
@ -324,19 +324,24 @@ class Bot(object):
bool bool
whether any commands where successfully triggered by the message whether any commands where successfully triggered by the message
""" """
commands = list(self.get_commands_for_message(msg)) commands = list(self.get_commands_for_message(
self.config.commands_require_mention,
if len(commands): self.config.commands_mention_rules,
result = False self.config.commands_prefix,
for command, match in commands: msg
if not self.check_command_permissions(command, msg): ))
continue
if not len(commands):
if command.plugin.execute(CommandEvent(command, msg, match)): return False
result = True
return result result = False
for command, match in commands:
return False if not self.check_command_permissions(command, msg):
continue
if command.plugin.execute(CommandEvent(command, msg, match)):
result = True
return result
def on_message_create(self, event): def on_message_create(self, event):
if event.message.author.id == self.client.state.me.id: if event.message.author.id == self.client.state.me.id:

31
disco/bot/command.py

@ -140,11 +140,14 @@ class Command(object):
return ctx.msg.guild.roles.get(rid) return ctx.msg.guild.roles.get(rid)
def resolve_user(ctx, uid): def resolve_user(ctx, uid):
return ctx.msg.mentions.get(uid) if isinstance(uid, int):
return ctx.msg.mentions.get(uid)
else:
return ctx.msg.mentions.select_one(username=uid[0], discriminator=uid[1])
self.args = ArgumentSet.from_string(args or '', { self.args = ArgumentSet.from_string(args or '', {
'mention': self.mention_type([resolve_role, resolve_user]), 'mention': self.mention_type([resolve_role, resolve_user]),
'user': self.mention_type([resolve_user], force=True), 'user': self.mention_type([resolve_user], force=True, user=True),
'role': self.mention_type([resolve_role], force=True), 'role': self.mention_type([resolve_role], force=True),
}) })
@ -156,27 +159,29 @@ class Command(object):
self.dispatch_func = dispatch_func self.dispatch_func = dispatch_func
@staticmethod @staticmethod
def mention_type(getters, force=False): def mention_type(getters, force=False, user=False):
def _f(ctx, i): def _f(ctx, raw):
# TODO: support full discrim format? make this betteR? if raw.isdigit():
if i.isdigit(): resolved = int(raw)
mid = int(i) elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit():
username, discrim = raw.split('#')
resolved = (username, int(discrim))
else: else:
res = MENTION_RE.match(i) res = MENTION_RE.match(raw)
if not res: if not res:
raise TypeError('Invalid mention: {}'.format(i)) raise TypeError('Invalid mention: {}'.format(raw))
mid = int(res.group(1)) resolved = int(res.group(1))
for getter in getters: for getter in getters:
obj = getter(ctx, mid) obj = getter(ctx, resolved)
if obj: if obj:
return obj return obj
if force: if force:
raise TypeError('Cannot resolve mention: {}'.format(id)) raise TypeError('Cannot resolve mention: {}'.format(raw))
return mid return resolved
return _f return _f
@cached_property @cached_property

27
disco/types/base.py

@ -34,9 +34,10 @@ class ConversionError(Exception):
class Field(object): class Field(object):
def __init__(self, value_type, alias=None, default=None): def __init__(self, value_type, alias=None, default=None, **kwargs):
self.src_name = alias self.src_name = alias
self.dst_name = None self.dst_name = None
self.metadata = kwargs
if default is not None: if default is not None:
self.default = default self.default = default
@ -50,6 +51,8 @@ class Field(object):
if isinstance(self.deserializer, Field) and self.default is None: if isinstance(self.deserializer, Field) and self.default is None:
self.default = self.deserializer.default self.default = self.deserializer.default
elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None:
self.default = self.deserializer
@property @property
def name(self): def name(self):
@ -275,8 +278,22 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
else: else:
obj = kwargs obj = kwargs
for name, field in six.iteritems(self.__class__._fields): self.load(obj)
if field.src_name not in obj or obj[field.src_name] is None:
@property
def fields(self):
return self.__class__._fields
def load(self, obj, consume=False, skip=None):
for name, field in six.iteritems(self.fields):
should_skip = skip and name in skip
if consume and not should_skip:
raw = obj.pop(field.src_name, None)
else:
raw = obj.get(field.src_name, None)
if raw is None or should_skip:
if field.has_default(): if field.has_default():
default = field.default() if callable(field.default) else field.default default = field.default() if callable(field.default) else field.default
else: else:
@ -284,11 +301,11 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
setattr(self, field.dst_name, default) setattr(self, field.dst_name, default)
continue continue
value = field.try_convert(obj[field.src_name], self.client) value = field.try_convert(raw, self.client)
setattr(self, field.dst_name, value) setattr(self, field.dst_name, value)
def update(self, other): def update(self, other):
for name in six.iterkeys(self.__class__._fields): for name in six.iterkeys(self.fields):
if hasattr(other, name) and not getattr(other, name) is UNSET: if hasattr(other, name) and not getattr(other, name) is UNSET:
setattr(self, name, getattr(other, name)) setattr(self, name, getattr(other, name))

58
disco/types/message.py

@ -1,4 +1,5 @@
import re import re
import functools
from holster.enum import Enum from holster.enum import Enum
@ -310,18 +311,33 @@ class Message(SlottedModel):
return entity in self.mentions or entity in self.mention_roles return entity in self.mentions or entity in self.mention_roles
@cached_property @cached_property
def without_mentions(self): def without_mentions(self, valid_only=False):
""" """
Returns Returns
------- -------
str str
the message contents with all valid mentions removed. the message contents with all mentions removed.
""" """
return self.replace_mentions( return self.replace_mentions(
lambda u: '', lambda u: '',
lambda r: '') lambda r: '',
lambda c: '',
nonexistant=not valid_only)
def replace_mentions(self, user_replace, role_replace): @cached_property
def with_proper_mentions(self):
def replace_user(u):
return '@' + str(u)
def replace_role(r):
return '@' + str(r)
def replace_channel(c):
return str(c)
return self.replace_mentions(replace_user, replace_role, replace_channel)
def replace_mentions(self, user_replace=None, role_replace=None, channel_replace=None, nonexistant=False):
""" """
Replaces user and role mentions with the result of a given lambda/function. Replaces user and role mentions with the result of a given lambda/function.
@ -339,17 +355,30 @@ class Message(SlottedModel):
str str
The message contents with all valid mentions replaced. The message contents with all valid mentions replaced.
""" """
if not self.mentions and not self.mention_roles: def replace(getter, func, match):
return oid = int(match.group(2))
obj = getter(oid)
if obj or nonexistant:
return func(obj or oid) or match.group(0)
return match.group(0)
content = self.content
if user_replace:
replace_user = functools.partial(replace, self.mentions.get, user_replace)
content = re.sub('(<@!?([0-9]+)>)', replace_user, self.content)
if role_replace:
replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace)
content = re.sub('(<@&([0-9]+)>)', replace_role, content)
def replace(match): if channel_replace:
oid = match.group(0) replace_channel = functools.partial(replace, self.client.state.channels.get, channel_replace)
if oid in self.mention_roles: content = re.sub('(<#([0-9]+)>)', replace_channel, content)
return role_replace(oid)
else:
return user_replace(self.mentions.get(oid))
return re.sub('<@!?([0-9]+)>', replace, self.content) return content
class MessageTable(object): class MessageTable(object):
@ -371,8 +400,7 @@ class MessageTable(object):
self.recalculate_size_index(args) self.recalculate_size_index(args)
def add(self, *args): def add(self, *args):
convert = lambda v: v if isinstance(v, basestring) else str(v) args = list(map(lambda v: v if isinstance(v, basestring) else str(v), args))
args = list(map(convert, args))
self.entries.append(args) self.entries.append(args)
self.recalculate_size_index(args) self.recalculate_size_index(args)

Loading…
Cancel
Save