diff --git a/disco/bot/command.py b/disco/bot/command.py index 08c9b9f..1cf00a1 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -7,7 +7,10 @@ from disco.util.functional import cached_property REGEX_FMT = '({})' ARGS_REGEX = '( ((?:\n|.)*)$|$)' -MENTION_RE = re.compile('<@!?([0-9]+)>') + +USER_MENTION_RE = re.compile('<@!?([0-9]+)>') +ROLE_MENTION_RE = re.compile('<@&([0-9]+)>') +CHANNEL_MENTION_RE = re.compile('<#([0-9]+)>') CommandLevels = Enum( DEFAULT=0, @@ -145,10 +148,13 @@ class Command(object): else: return ctx.msg.mentions.select_one(username=uid[0], discriminator=uid[1]) + def resolve_channel(ctx, cid): + return ctx.msg.guild.channels.get(cid) + self.args = ArgumentSet.from_string(args or '', { - 'mention': self.mention_type([resolve_role, resolve_user]), - 'user': self.mention_type([resolve_user], force=True, user=True), - 'role': self.mention_type([resolve_role], force=True), + 'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True), + 'role': self.mention_type([resolve_role], ROLE_MENTION_RE), + 'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE), }) self.level = level @@ -159,7 +165,7 @@ class Command(object): self.dispatch_func = dispatch_func @staticmethod - def mention_type(getters, force=False, user=False): + def mention_type(getters, reg, user=False): def _f(ctx, raw): if raw.isdigit(): resolved = int(raw) @@ -167,7 +173,7 @@ class Command(object): username, discrim = raw.split('#') resolved = (username, int(discrim)) else: - res = MENTION_RE.match(raw) + res = reg.match(raw) if not res: raise TypeError('Invalid mention: {}'.format(raw)) @@ -178,10 +184,7 @@ class Command(object): if obj: return obj - if force: - raise TypeError('Cannot resolve mention: {}'.format(raw)) - - return resolved + raise TypeError('Cannot resolve mention: {}'.format(raw)) return _f @cached_property diff --git a/disco/types/base.py b/disco/types/base.py index 56657ff..dfa1e2c 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -73,6 +73,7 @@ class Field(object): try: return self.deserializer(raw, client) except Exception as e: + raise six.reraise(ConversionError, ConversionError(self, raw, e)) @staticmethod @@ -281,11 +282,11 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): self.load(obj) @property - def fields(self): + def _fields(self): return self.__class__._fields def load(self, obj, consume=False, skip=None): - for name, field in six.iteritems(self.fields): + for name, field in six.iteritems(self._fields): should_skip = skip and name in skip if consume and not should_skip: @@ -305,7 +306,7 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): setattr(self, field.dst_name, value) def update(self, other): - for name in six.iterkeys(self.fields): + for name in six.iterkeys(self._fields): if hasattr(other, name) and not getattr(other, name) is UNSET: setattr(self, name, getattr(other, name))