Browse Source

Various fixes and improvements

- Add support for attachments and message embeds
- Fix commands being weirdly stored by some key (which doesn't make
sense)
-  Added CommandEvent.codeblock which represents the first codeblock in
the message (useful for eval like commands)
- Cleanup the spawn utilties on plugin a bit
- Fix GuildBanAdd/GuildBanRemove
- Unset model fields are now a special sentinel value
- etc stuff
pull/11/head
Andrei 9 years ago
parent
commit
b5284c1975
  1. 16
      disco/api/client.py
  2. 2
      disco/bot/bot.py
  3. 26
      disco/bot/command.py
  4. 42
      disco/bot/plugin.py
  5. 20
      disco/gateway/events.py
  6. 18
      disco/state.py
  7. 1
      disco/types/__init__.py
  8. 28
      disco/types/base.py
  9. 9
      disco/types/channel.py
  10. 18
      disco/types/user.py

16
disco/api/client.py

@ -77,12 +77,22 @@ class APIClient(LoggingClass):
r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message))
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())
def channels_messages_create(self, channel, content, nonce=None, tts=False): def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None):
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={ payload = {
'content': content, 'content': content,
'nonce': nonce, 'nonce': nonce,
'tts': tts, 'tts': tts,
}) }
if embed:
payload['embed'] = embed.to_dict()
if attachment:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data=payload, files={
'file': (attachment[0], attachment[1])
})
else:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload)
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())

2
disco/bot/bot.py

@ -180,7 +180,7 @@ class Bot(object):
Generator of all commands this bots plugins have defined. Generator of all commands this bots plugins have defined.
""" """
for plugin in six.itervalues(self.plugins): for plugin in six.itervalues(self.plugins):
for command in six.itervalues(plugin.commands): for command in plugin.commands:
yield command yield command
def recompute(self): def recompute(self):

26
disco/bot/command.py

@ -45,6 +45,18 @@ class CommandEvent(object):
self.name = self.match.group(1) self.name = self.match.group(1)
self.args = [i for i in self.match.group(2).strip().split(' ') if i] self.args = [i for i in self.match.group(2).strip().split(' ') if i]
@property
def codeblock(self):
_, src = self.msg.content.split('`', 1)
src = '`' + src
if src.startswith('```') and src.endswith('```'):
src = src[3:-3]
elif src.startswith('`') and src.endswith('`'):
src = src[1:-1]
return src
@cached_property @cached_property
def member(self): def member(self):
""" """
@ -146,11 +158,15 @@ class Command(object):
@staticmethod @staticmethod
def mention_type(getters, force=False): def mention_type(getters, force=False):
def _f(ctx, i): def _f(ctx, i):
res = MENTION_RE.match(i) # TODO: support full discrim format? make this betteR?
if not res: if i.isdigit():
raise TypeError('Invalid mention: {}'.format(i)) mid = int(i)
else:
mid = int(res.group(1)) res = MENTION_RE.match(i)
if not res:
raise TypeError('Invalid mention: {}'.format(i))
mid = int(res.group(1))
for getter in getters: for getter in getters:
obj = getter(ctx, mid) obj = getter(ctx, mid)

42
disco/bot/plugin.py

@ -156,7 +156,7 @@ class Plugin(LoggingClass, PluginDeco):
# General declartions # General declartions
self.listeners = [] self.listeners = []
self.commands = {} self.commands = []
self.schedules = {} self.schedules = {}
self.greenlets = weakref.WeakSet() self.greenlets = weakref.WeakSet()
self._pre = {} self._pre = {}
@ -182,7 +182,7 @@ class Plugin(LoggingClass, PluginDeco):
def bind_all(self): def bind_all(self):
self.listeners = [] self.listeners = []
self.commands = {} self.commands = []
self.schedules = {} self.schedules = {}
self.greenlets = weakref.WeakSet() self.greenlets = weakref.WeakSet()
@ -197,7 +197,7 @@ class Plugin(LoggingClass, PluginDeco):
if meta['type'] == 'listener': if meta['type'] == 'listener':
self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs']) self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs'])
elif meta['type'] == 'command': elif meta['type'] == 'command':
meta['kwargs']['update'] = True # meta['kwargs']['update'] = True
self.register_command(member, *meta['args'], **meta['kwargs']) self.register_command(member, *meta['args'], **meta['kwargs'])
elif meta['type'] == 'schedule': elif meta['type'] == 'schedule':
self.register_schedule(member, *meta['args'], **meta['kwargs']) self.register_schedule(member, *meta['args'], **meta['kwargs'])
@ -205,11 +205,25 @@ class Plugin(LoggingClass, PluginDeco):
when, typ = meta['type'].split('_', 1) when, typ = meta['type'].split('_', 1)
self.register_trigger(typ, when, member) self.register_trigger(typ, when, member)
def spawn(self, method, *args, **kwargs): def spawn_wrap(self, spawner, method, *args, **kwargs):
obj = gevent.spawn(method, *args, **kwargs) def wrapped(*args, **kwargs):
self.ctx['plugin'] = self
try:
res = method(*args, **kwargs)
return res
finally:
self.ctx.drop()
obj = spawner(wrapped, *args, **kwargs)
self.greenlets.add(obj) self.greenlets.add(obj)
return obj return obj
def spawn(self, *args, **kwargs):
return self.spawn_wrap(gevent.spawn, *args, **kwargs)
def spawn_later(self, delay, *args, **kwargs):
return self.spawn_wrap(functools.partial(gevent.spawn_later, delay), *args, **kwargs)
def execute(self, event): def execute(self, event):
""" """
Executes a CommandEvent this plugin owns. Executes a CommandEvent this plugin owns.
@ -294,14 +308,14 @@ class Plugin(LoggingClass, PluginDeco):
Keyword arguments to pass onto the :class:`disco.bot.command.Command` Keyword arguments to pass onto the :class:`disco.bot.command.Command`
object. object.
""" """
name = args[0] # name = args[0]
if kwargs.pop('update', False) and name in self.commands: # if kwargs.pop('update', False) and name in self.commands:
self.commands[name].update(*args, **kwargs) # self.commands[name].update(*args, **kwargs)
else: # else:
wrapped = functools.partial(self._dispatch, 'command', func) wrapped = functools.partial(self._dispatch, 'command', func)
kwargs.setdefault('dispatch_func', wrapped) kwargs.setdefault('dispatch_func', wrapped)
self.commands[name] = Command(self, func, *args, **kwargs) self.commands.append(Command(self, func, *args, **kwargs))
def register_schedule(self, func, interval, repeat=True, init=True): def register_schedule(self, func, interval, repeat=True, init=True):
""" """
@ -320,7 +334,7 @@ class Plugin(LoggingClass, PluginDeco):
Whether to run this schedule once immediatly, or wait for the first Whether to run this schedule once immediatly, or wait for the first
scheduled iteration. scheduled iteration.
""" """
def func(): def repeat_func():
if init: if init:
func() func()
@ -330,7 +344,7 @@ class Plugin(LoggingClass, PluginDeco):
if not repeat: if not repeat:
break break
self.schedules[func.__name__] = self.spawn(repeat) self.schedules[func.__name__] = self.spawn(repeat_func)
def load(self, ctx): def load(self, ctx):
""" """

20
disco/gateway/events.py

@ -67,15 +67,16 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)):
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
def debug(func=None): def debug(func=None, match=None):
def deco(cls): def deco(cls):
old_init = cls.__init__ old_init = cls.__init__
def new_init(self, obj, *args, **kwargs): def new_init(self, obj, *args, **kwargs):
if func: if not match or match(obj):
print(func(obj)) if func:
else: print(func(obj))
print(obj) else:
print(obj)
old_init(self, obj, *args, **kwargs) old_init(self, obj, *args, **kwargs)
@ -244,7 +245,7 @@ class ChannelPinsUpdate(GatewayEvent):
last_pin_timestamp = Field(lazy_datetime) last_pin_timestamp = Field(lazy_datetime)
@wraps_model(User) @proxy(User)
class GuildBanAdd(GatewayEvent): class GuildBanAdd(GatewayEvent):
""" """
Sent when a user is banned from a guild. Sent when a user is banned from a guild.
@ -257,13 +258,14 @@ class GuildBanAdd(GatewayEvent):
The user being banned from the guild. The user being banned from the guild.
""" """
guild_id = Field(snowflake) guild_id = Field(snowflake)
user = Field(User)
@property @property
def guild(self): def guild(self):
return self.client.state.guilds.get(self.guild_id) return self.client.state.guilds.get(self.guild_id)
@wraps_model(User) @proxy(User)
class GuildBanRemove(GuildBanAdd): class GuildBanRemove(GuildBanAdd):
""" """
Sent when a user is unbanned from a guild. Sent when a user is unbanned from a guild.
@ -507,6 +509,10 @@ class PresenceUpdate(GatewayEvent):
guild_id = Field(snowflake) guild_id = Field(snowflake)
roles = ListField(snowflake) roles = ListField(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class TypingStart(GatewayEvent): class TypingStart(GatewayEvent):
""" """

18
disco/state.py

@ -1,8 +1,8 @@
import six import six
import weakref
import inflection import inflection
from collections import deque, namedtuple from collections import deque, namedtuple
from weakref import WeakValueDictionary
from gevent.event import Event from gevent.event import Event
from disco.util.config import Config from disco.util.config import Config
@ -102,9 +102,9 @@ class State(object):
self.me = None self.me = None
self.dms = HashMap() self.dms = HashMap()
self.guilds = HashMap() self.guilds = HashMap()
self.channels = HashMap(WeakValueDictionary()) self.channels = HashMap(weakref.WeakValueDictionary())
self.users = HashMap(WeakValueDictionary()) self.users = HashMap(weakref.WeakValueDictionary())
self.voice_states = HashMap(WeakValueDictionary()) self.voice_states = HashMap(weakref.WeakValueDictionary())
# If message tracking is enabled, listen to those events # If message tracking is enabled, listen to those events
if self.config.track_messages: if self.config.track_messages:
@ -298,4 +298,14 @@ class State(object):
def on_presence_update(self, event): def on_presence_update(self, event):
if event.user.id in self.users: if event.user.id in self.users:
self.users[event.user.id].update(event.presence.user)
self.users[event.user.id].presence = event.presence self.users[event.user.id].presence = event.presence
event.presence.user = self.users[event.user.id]
if event.guild_id not in self.guilds:
return
if event.user.id not in self.guilds[event.guild_id].members:
return
self.guilds[event.guild_id].members[event.user.id].user.update(event.user)

1
disco/types/__init__.py

@ -1,3 +1,4 @@
from disco.types.base import UNSET
from disco.types.channel import Channel from disco.types.channel import Channel
from disco.types.guild import Guild, GuildMember, Role from disco.types.guild import Guild, GuildMember, Role
from disco.types.user import User from disco.types.user import User

28
disco/types/base.py

@ -15,6 +15,14 @@ DATETIME_FORMATS = [
] ]
class Unset(object):
def __nonzero__(self):
return False
UNSET = Unset()
class ConversionError(Exception): class ConversionError(Exception):
def __init__(self, field, raw, e): def __init__(self, field, raw, e):
super(ConversionError, self).__init__( super(ConversionError, self).__init__(
@ -26,10 +34,9 @@ class ConversionError(Exception):
class Field(object): class Field(object):
def __init__(self, value_type, alias=None, default=None, test=0): def __init__(self, value_type, alias=None, default=None):
self.src_name = alias self.src_name = alias
self.dst_name = None self.dst_name = None
self.test = test
if default is not None: if default is not None:
self.default = default self.default = default
@ -97,6 +104,10 @@ class DictField(Field):
self.key_de = self.type_to_deserializer(key_type) self.key_de = self.type_to_deserializer(key_type)
self.value_de = self.type_to_deserializer(value_type or key_type) self.value_de = self.type_to_deserializer(value_type or key_type)
@staticmethod
def serialize(value):
return {Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value)}
def try_convert(self, raw, client): def try_convert(self, raw, client):
return HashMap({ return HashMap({
self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw)
@ -106,6 +117,10 @@ class DictField(Field):
class ListField(Field): class ListField(Field):
default = list default = list
@staticmethod
def serialize(value):
return list(map(Field.serialize, value))
def try_convert(self, raw, client): def try_convert(self, raw, client):
return [self.deserializer(i, client) for i in raw] return [self.deserializer(i, client) for i in raw]
@ -265,7 +280,7 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
if field.has_default(): if field.has_default():
default = field.default() if callable(field.default) else field.default default = field.default() if callable(field.default) else field.default
else: else:
default = None default = UNSET
setattr(self, field.dst_name, default) setattr(self, field.dst_name, default)
continue continue
@ -274,9 +289,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
def update(self, other): def update(self, other):
for name in six.iterkeys(self.__class__._fields): for name in six.iterkeys(self.__class__._fields):
value = getattr(other, name) if hasattr(other, name) and not getattr(other, name) is UNSET:
if value: setattr(self, name, getattr(other, name))
setattr(self, name, value)
# Clear cached properties # Clear cached properties
for name in dir(type(self)): for name in dir(type(self)):
@ -289,6 +303,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
def to_dict(self): def to_dict(self):
obj = {} obj = {}
for name, field in six.iteritems(self.__class__._fields): for name, field in six.iteritems(self.__class__._fields):
if getattr(self, name) == UNSET:
continue
obj[name] = field.serialize(getattr(self, name)) obj[name] = field.serialize(getattr(self, name))
return obj return obj

9
disco/types/channel.py

@ -121,7 +121,10 @@ class Channel(SlottedModel, Permissible):
self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self})
def __str__(self): def __str__(self):
return '#{}'.format(self.name) return u'#{}'.format(self.name)
def __repr__(self):
return u'<Channel {} ({})>'.format(self.id, self)
def get_permissions(self, user): def get_permissions(self, user):
""" """
@ -230,7 +233,7 @@ class Channel(SlottedModel, Permissible):
def create_webhook(self, name=None, avatar=None): def create_webhook(self, name=None, avatar=None):
return self.client.api.channels_webhooks_create(self.id, name, avatar) return self.client.api.channels_webhooks_create(self.id, name, avatar)
def send_message(self, content, nonce=None, tts=False): def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None):
""" """
Send a message in this channel. Send a message in this channel.
@ -248,7 +251,7 @@ class Channel(SlottedModel, Permissible):
:class:`disco.types.message.Message` :class:`disco.types.message.Message`
The created message. The created message.
""" """
return self.client.api.channels_messages_create(self.id, content, nonce, tts) return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed)
def connect(self, *args, **kwargs): def connect(self, *args, **kwargs):
""" """

18
disco/types/user.py

@ -14,18 +14,24 @@ class User(SlottedModel, with_equality('id'), with_hash('id')):
presence = Field(None) presence = Field(None)
@property
def avatar_url(self):
if not self.avatar:
return None
return 'https://discordapp.com/api/users/{}/avatars/{}.jpg'.format(
self.id,
self.avatar)
@property @property
def mention(self): def mention(self):
return '<@{}>'.format(self.id) return '<@{}>'.format(self.id)
def __str__(self): def __str__(self):
return '{}#{}'.format(self.username, self.discriminator) return u'{}#{}'.format(self.username, self.discriminator)
def __repr__(self): def __repr__(self):
return '<User {} ({})>'.format(self.id, self.to_string()) return u'<User {} ({})>'.format(self.id, self)
def on_create(self):
self.client.state.users[self.id] = self
GameType = Enum( GameType = Enum(
@ -49,6 +55,6 @@ class Game(SlottedModel):
class Presence(SlottedModel): class Presence(SlottedModel):
user = Field(User) user = Field(User, alias='user')
game = Field(Game) game = Field(Game)
status = Field(Status) status = Field(Status)

Loading…
Cancel
Save