Browse Source

Improvements to command processing

- Bot.get_commands_for_message is now more composable/externally useable
- Command parser for 'user' type has been improved to allow
username/discrim combos
- Model loading can now be done outside of the model constructor, and
supports some utility arguments
- Fix sub-model fields not having their default value be the sub-model
constructor
- Fix Message.without_mentions
- Add Message.with_proper_mentions (e.g. humanifying the message)
- Cleanup Message.replace_mentions for the above two changes
- Fix some weird casting inside MessageTable
pull/11/head
Andrei 9 years ago
parent
commit
dd75502b89
  1. 45
      disco/bot/bot.py
  2. 31
      disco/bot/command.py
  3. 27
      disco/types/base.py
  4. 58
      disco/types/message.py

45
disco/bot/bot.py

@ -221,7 +221,7 @@ class Bot(object):
else:
self.command_matches_re = None
def get_commands_for_message(self, msg):
def get_commands_for_message(self, require_mention, mention_rules, prefix, msg):
"""
Generator of all commands that a given message object triggers, based on
the bots plugins and configuration.
@ -238,7 +238,7 @@ class Bot(object):
"""
content = msg.content
if self.config.commands_require_mention:
if require_mention:
mention_direct = msg.is_mentioned(self.client.state.me)
mention_everyone = msg.mention_everyone
@ -248,9 +248,9 @@ class Bot(object):
msg.guild.get_member(self.client.state.me).roles))
if not any((
self.config.commands_mention_rules['user'] and mention_direct,
self.config.commands_mention_rules['everyone'] and mention_everyone,
self.config.commands_mention_rules['role'] and any(mention_roles),
mention_rules.get('user', True) and mention_direct,
mention_rules.get('everyone', False) and mention_everyone,
mention_rules.get('role', False) and any(mention_roles),
msg.channel.is_dm
)):
raise StopIteration
@ -270,10 +270,10 @@ class Bot(object):
content = content.lstrip()
if self.config.commands_prefix and not content.startswith(self.config.commands_prefix):
if prefix and not content.startswith(prefix):
raise StopIteration
else:
content = content[len(self.config.commands_prefix):]
content = content[len(prefix):]
if not self.command_matches_re or not self.command_matches_re.match(content):
raise StopIteration
@ -324,19 +324,24 @@ class Bot(object):
bool
whether any commands where successfully triggered by the message
"""
commands = list(self.get_commands_for_message(msg))
if len(commands):
result = False
for command, match in commands:
if not self.check_command_permissions(command, msg):
continue
if command.plugin.execute(CommandEvent(command, msg, match)):
result = True
return result
return False
commands = list(self.get_commands_for_message(
self.config.commands_require_mention,
self.config.commands_mention_rules,
self.config.commands_prefix,
msg
))
if not len(commands):
return False
result = False
for command, match in commands:
if not self.check_command_permissions(command, msg):
continue
if command.plugin.execute(CommandEvent(command, msg, match)):
result = True
return result
def on_message_create(self, event):
if event.message.author.id == self.client.state.me.id:

31
disco/bot/command.py

@ -140,11 +140,14 @@ class Command(object):
return ctx.msg.guild.roles.get(rid)
def resolve_user(ctx, uid):
return ctx.msg.mentions.get(uid)
if isinstance(uid, int):
return ctx.msg.mentions.get(uid)
else:
return ctx.msg.mentions.select_one(username=uid[0], discriminator=uid[1])
self.args = ArgumentSet.from_string(args or '', {
'mention': self.mention_type([resolve_role, resolve_user]),
'user': self.mention_type([resolve_user], force=True),
'user': self.mention_type([resolve_user], force=True, user=True),
'role': self.mention_type([resolve_role], force=True),
})
@ -156,27 +159,29 @@ class Command(object):
self.dispatch_func = dispatch_func
@staticmethod
def mention_type(getters, force=False):
def _f(ctx, i):
# TODO: support full discrim format? make this betteR?
if i.isdigit():
mid = int(i)
def mention_type(getters, force=False, user=False):
def _f(ctx, raw):
if raw.isdigit():
resolved = int(raw)
elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit():
username, discrim = raw.split('#')
resolved = (username, int(discrim))
else:
res = MENTION_RE.match(i)
res = MENTION_RE.match(raw)
if not res:
raise TypeError('Invalid mention: {}'.format(i))
raise TypeError('Invalid mention: {}'.format(raw))
mid = int(res.group(1))
resolved = int(res.group(1))
for getter in getters:
obj = getter(ctx, mid)
obj = getter(ctx, resolved)
if obj:
return obj
if force:
raise TypeError('Cannot resolve mention: {}'.format(id))
raise TypeError('Cannot resolve mention: {}'.format(raw))
return mid
return resolved
return _f
@cached_property

27
disco/types/base.py

@ -34,9 +34,10 @@ class ConversionError(Exception):
class Field(object):
def __init__(self, value_type, alias=None, default=None):
def __init__(self, value_type, alias=None, default=None, **kwargs):
self.src_name = alias
self.dst_name = None
self.metadata = kwargs
if default is not None:
self.default = default
@ -50,6 +51,8 @@ class Field(object):
if isinstance(self.deserializer, Field) and self.default is None:
self.default = self.deserializer.default
elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None:
self.default = self.deserializer
@property
def name(self):
@ -275,8 +278,22 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
else:
obj = kwargs
for name, field in six.iteritems(self.__class__._fields):
if field.src_name not in obj or obj[field.src_name] is None:
self.load(obj)
@property
def fields(self):
return self.__class__._fields
def load(self, obj, consume=False, skip=None):
for name, field in six.iteritems(self.fields):
should_skip = skip and name in skip
if consume and not should_skip:
raw = obj.pop(field.src_name, None)
else:
raw = obj.get(field.src_name, None)
if raw is None or should_skip:
if field.has_default():
default = field.default() if callable(field.default) else field.default
else:
@ -284,11 +301,11 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
setattr(self, field.dst_name, default)
continue
value = field.try_convert(obj[field.src_name], self.client)
value = field.try_convert(raw, self.client)
setattr(self, field.dst_name, value)
def update(self, other):
for name in six.iterkeys(self.__class__._fields):
for name in six.iterkeys(self.fields):
if hasattr(other, name) and not getattr(other, name) is UNSET:
setattr(self, name, getattr(other, name))

58
disco/types/message.py

@ -1,4 +1,5 @@
import re
import functools
from holster.enum import Enum
@ -310,18 +311,33 @@ class Message(SlottedModel):
return entity in self.mentions or entity in self.mention_roles
@cached_property
def without_mentions(self):
def without_mentions(self, valid_only=False):
"""
Returns
-------
str
the message contents with all valid mentions removed.
the message contents with all mentions removed.
"""
return self.replace_mentions(
lambda u: '',
lambda r: '')
lambda r: '',
lambda c: '',
nonexistant=not valid_only)
def replace_mentions(self, user_replace, role_replace):
@cached_property
def with_proper_mentions(self):
def replace_user(u):
return '@' + str(u)
def replace_role(r):
return '@' + str(r)
def replace_channel(c):
return str(c)
return self.replace_mentions(replace_user, replace_role, replace_channel)
def replace_mentions(self, user_replace=None, role_replace=None, channel_replace=None, nonexistant=False):
"""
Replaces user and role mentions with the result of a given lambda/function.
@ -339,17 +355,30 @@ class Message(SlottedModel):
str
The message contents with all valid mentions replaced.
"""
if not self.mentions and not self.mention_roles:
return
def replace(getter, func, match):
oid = int(match.group(2))
obj = getter(oid)
if obj or nonexistant:
return func(obj or oid) or match.group(0)
return match.group(0)
content = self.content
if user_replace:
replace_user = functools.partial(replace, self.mentions.get, user_replace)
content = re.sub('(<@!?([0-9]+)>)', replace_user, self.content)
if role_replace:
replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace)
content = re.sub('(<@&([0-9]+)>)', replace_role, content)
def replace(match):
oid = match.group(0)
if oid in self.mention_roles:
return role_replace(oid)
else:
return user_replace(self.mentions.get(oid))
if channel_replace:
replace_channel = functools.partial(replace, self.client.state.channels.get, channel_replace)
content = re.sub('(<#([0-9]+)>)', replace_channel, content)
return re.sub('<@!?([0-9]+)>', replace, self.content)
return content
class MessageTable(object):
@ -371,8 +400,7 @@ class MessageTable(object):
self.recalculate_size_index(args)
def add(self, *args):
convert = lambda v: v if isinstance(v, basestring) else str(v)
args = list(map(convert, args))
args = list(map(lambda v: v if isinstance(v, basestring) else str(v), args))
self.entries.append(args)
self.recalculate_size_index(args)

Loading…
Cancel
Save