diff --git a/disco/api/client.py b/disco/api/client.py index 27152e4..3a15171 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -77,12 +77,22 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) return Message.create(self.client, r.json()) - def channels_messages_create(self, channel, content, nonce=None, tts=False): - r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={ + def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None): + payload = { 'content': content, 'nonce': nonce, 'tts': tts, - }) + } + + if embed: + payload['embed'] = embed.to_dict() + + if attachment: + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data=payload, files={ + 'file': (attachment[0], attachment[1]) + }) + else: + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload) return Message.create(self.client, r.json()) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 428f76a..ec670e1 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -180,7 +180,7 @@ class Bot(object): Generator of all commands this bots plugins have defined. """ for plugin in six.itervalues(self.plugins): - for command in six.itervalues(plugin.commands): + for command in plugin.commands: yield command def recompute(self): diff --git a/disco/bot/command.py b/disco/bot/command.py index e8ef431..40ccbea 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -45,6 +45,18 @@ class CommandEvent(object): self.name = self.match.group(1) self.args = [i for i in self.match.group(2).strip().split(' ') if i] + @property + def codeblock(self): + _, src = self.msg.content.split('`', 1) + src = '`' + src + + if src.startswith('```') and src.endswith('```'): + src = src[3:-3] + elif src.startswith('`') and src.endswith('`'): + src = src[1:-1] + + return src + @cached_property def member(self): """ @@ -146,11 +158,15 @@ class Command(object): @staticmethod def mention_type(getters, force=False): def _f(ctx, i): - res = MENTION_RE.match(i) - if not res: - raise TypeError('Invalid mention: {}'.format(i)) - - mid = int(res.group(1)) + # TODO: support full discrim format? make this betteR? + if i.isdigit(): + mid = int(i) + else: + res = MENTION_RE.match(i) + if not res: + raise TypeError('Invalid mention: {}'.format(i)) + + mid = int(res.group(1)) for getter in getters: obj = getter(ctx, mid) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 0e05136..fcbc9ad 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -156,7 +156,7 @@ class Plugin(LoggingClass, PluginDeco): # General declartions self.listeners = [] - self.commands = {} + self.commands = [] self.schedules = {} self.greenlets = weakref.WeakSet() self._pre = {} @@ -182,7 +182,7 @@ class Plugin(LoggingClass, PluginDeco): def bind_all(self): self.listeners = [] - self.commands = {} + self.commands = [] self.schedules = {} self.greenlets = weakref.WeakSet() @@ -197,7 +197,7 @@ class Plugin(LoggingClass, PluginDeco): if meta['type'] == 'listener': self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs']) elif meta['type'] == 'command': - meta['kwargs']['update'] = True + # meta['kwargs']['update'] = True self.register_command(member, *meta['args'], **meta['kwargs']) elif meta['type'] == 'schedule': self.register_schedule(member, *meta['args'], **meta['kwargs']) @@ -205,11 +205,25 @@ class Plugin(LoggingClass, PluginDeco): when, typ = meta['type'].split('_', 1) self.register_trigger(typ, when, member) - def spawn(self, method, *args, **kwargs): - obj = gevent.spawn(method, *args, **kwargs) + def spawn_wrap(self, spawner, method, *args, **kwargs): + def wrapped(*args, **kwargs): + self.ctx['plugin'] = self + try: + res = method(*args, **kwargs) + return res + finally: + self.ctx.drop() + + obj = spawner(wrapped, *args, **kwargs) self.greenlets.add(obj) return obj + def spawn(self, *args, **kwargs): + return self.spawn_wrap(gevent.spawn, *args, **kwargs) + + def spawn_later(self, delay, *args, **kwargs): + return self.spawn_wrap(functools.partial(gevent.spawn_later, delay), *args, **kwargs) + def execute(self, event): """ Executes a CommandEvent this plugin owns. @@ -294,14 +308,14 @@ class Plugin(LoggingClass, PluginDeco): Keyword arguments to pass onto the :class:`disco.bot.command.Command` object. """ - name = args[0] + # name = args[0] - if kwargs.pop('update', False) and name in self.commands: - self.commands[name].update(*args, **kwargs) - else: - wrapped = functools.partial(self._dispatch, 'command', func) - kwargs.setdefault('dispatch_func', wrapped) - self.commands[name] = Command(self, func, *args, **kwargs) + # if kwargs.pop('update', False) and name in self.commands: + # self.commands[name].update(*args, **kwargs) + # else: + wrapped = functools.partial(self._dispatch, 'command', func) + kwargs.setdefault('dispatch_func', wrapped) + self.commands.append(Command(self, func, *args, **kwargs)) def register_schedule(self, func, interval, repeat=True, init=True): """ @@ -320,7 +334,7 @@ class Plugin(LoggingClass, PluginDeco): Whether to run this schedule once immediatly, or wait for the first scheduled iteration. """ - def func(): + def repeat_func(): if init: func() @@ -330,7 +344,7 @@ class Plugin(LoggingClass, PluginDeco): if not repeat: break - self.schedules[func.__name__] = self.spawn(repeat) + self.schedules[func.__name__] = self.spawn(repeat_func) def load(self, ctx): """ diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 3e8af68..5fc28cf 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -67,15 +67,16 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)): return object.__getattribute__(self, name) -def debug(func=None): +def debug(func=None, match=None): def deco(cls): old_init = cls.__init__ def new_init(self, obj, *args, **kwargs): - if func: - print(func(obj)) - else: - print(obj) + if not match or match(obj): + if func: + print(func(obj)) + else: + print(obj) old_init(self, obj, *args, **kwargs) @@ -244,7 +245,7 @@ class ChannelPinsUpdate(GatewayEvent): last_pin_timestamp = Field(lazy_datetime) -@wraps_model(User) +@proxy(User) class GuildBanAdd(GatewayEvent): """ Sent when a user is banned from a guild. @@ -257,13 +258,14 @@ class GuildBanAdd(GatewayEvent): The user being banned from the guild. """ guild_id = Field(snowflake) + user = Field(User) @property def guild(self): return self.client.state.guilds.get(self.guild_id) -@wraps_model(User) +@proxy(User) class GuildBanRemove(GuildBanAdd): """ Sent when a user is unbanned from a guild. @@ -507,6 +509,10 @@ class PresenceUpdate(GatewayEvent): guild_id = Field(snowflake) roles = ListField(snowflake) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + class TypingStart(GatewayEvent): """ diff --git a/disco/state.py b/disco/state.py index cf4636e..81efbb3 100644 --- a/disco/state.py +++ b/disco/state.py @@ -1,8 +1,8 @@ import six +import weakref import inflection from collections import deque, namedtuple -from weakref import WeakValueDictionary from gevent.event import Event from disco.util.config import Config @@ -102,9 +102,9 @@ class State(object): self.me = None self.dms = HashMap() self.guilds = HashMap() - self.channels = HashMap(WeakValueDictionary()) - self.users = HashMap(WeakValueDictionary()) - self.voice_states = HashMap(WeakValueDictionary()) + self.channels = HashMap(weakref.WeakValueDictionary()) + self.users = HashMap(weakref.WeakValueDictionary()) + self.voice_states = HashMap(weakref.WeakValueDictionary()) # If message tracking is enabled, listen to those events if self.config.track_messages: @@ -298,4 +298,14 @@ class State(object): def on_presence_update(self, event): if event.user.id in self.users: + self.users[event.user.id].update(event.presence.user) self.users[event.user.id].presence = event.presence + event.presence.user = self.users[event.user.id] + + if event.guild_id not in self.guilds: + return + + if event.user.id not in self.guilds[event.guild_id].members: + return + + self.guilds[event.guild_id].members[event.user.id].user.update(event.user) diff --git a/disco/types/__init__.py b/disco/types/__init__.py index 5e6f73b..5824ec5 100644 --- a/disco/types/__init__.py +++ b/disco/types/__init__.py @@ -1,3 +1,4 @@ +from disco.types.base import UNSET from disco.types.channel import Channel from disco.types.guild import Guild, GuildMember, Role from disco.types.user import User diff --git a/disco/types/base.py b/disco/types/base.py index 17fde2e..1013906 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -15,6 +15,14 @@ DATETIME_FORMATS = [ ] +class Unset(object): + def __nonzero__(self): + return False + + +UNSET = Unset() + + class ConversionError(Exception): def __init__(self, field, raw, e): super(ConversionError, self).__init__( @@ -26,10 +34,9 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None, test=0): + def __init__(self, value_type, alias=None, default=None): self.src_name = alias self.dst_name = None - self.test = test if default is not None: self.default = default @@ -97,6 +104,10 @@ class DictField(Field): self.key_de = self.type_to_deserializer(key_type) self.value_de = self.type_to_deserializer(value_type or key_type) + @staticmethod + def serialize(value): + return {Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value)} + def try_convert(self, raw, client): return HashMap({ self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) @@ -106,6 +117,10 @@ class DictField(Field): class ListField(Field): default = list + @staticmethod + def serialize(value): + return list(map(Field.serialize, value)) + def try_convert(self, raw, client): return [self.deserializer(i, client) for i in raw] @@ -265,7 +280,7 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): if field.has_default(): default = field.default() if callable(field.default) else field.default else: - default = None + default = UNSET setattr(self, field.dst_name, default) continue @@ -274,9 +289,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): def update(self, other): for name in six.iterkeys(self.__class__._fields): - value = getattr(other, name) - if value: - setattr(self, name, value) + if hasattr(other, name) and not getattr(other, name) is UNSET: + setattr(self, name, getattr(other, name)) # Clear cached properties for name in dir(type(self)): @@ -289,6 +303,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): def to_dict(self): obj = {} for name, field in six.iteritems(self.__class__._fields): + if getattr(self, name) == UNSET: + continue obj[name] = field.serialize(getattr(self, name)) return obj diff --git a/disco/types/channel.py b/disco/types/channel.py index f5389d1..3912535 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -121,7 +121,10 @@ class Channel(SlottedModel, Permissible): self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) def __str__(self): - return '#{}'.format(self.name) + return u'#{}'.format(self.name) + + def __repr__(self): + return u''.format(self.id, self) def get_permissions(self, user): """ @@ -230,7 +233,7 @@ class Channel(SlottedModel, Permissible): def create_webhook(self, name=None, avatar=None): return self.client.api.channels_webhooks_create(self.id, name, avatar) - def send_message(self, content, nonce=None, tts=False): + def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None): """ Send a message in this channel. @@ -248,7 +251,7 @@ class Channel(SlottedModel, Permissible): :class:`disco.types.message.Message` The created message. """ - return self.client.api.channels_messages_create(self.id, content, nonce, tts) + return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed) def connect(self, *args, **kwargs): """ diff --git a/disco/types/user.py b/disco/types/user.py index a860843..c2bac79 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -14,18 +14,24 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): presence = Field(None) + @property + def avatar_url(self): + if not self.avatar: + return None + + return 'https://discordapp.com/api/users/{}/avatars/{}.jpg'.format( + self.id, + self.avatar) + @property def mention(self): return '<@{}>'.format(self.id) def __str__(self): - return '{}#{}'.format(self.username, self.discriminator) + return u'{}#{}'.format(self.username, self.discriminator) def __repr__(self): - return ''.format(self.id, self.to_string()) - - def on_create(self): - self.client.state.users[self.id] = self + return u''.format(self.id, self) GameType = Enum( @@ -49,6 +55,6 @@ class Game(SlottedModel): class Presence(SlottedModel): - user = Field(User) + user = Field(User, alias='user') game = Field(Game) status = Field(Status)