Browse Source

todo: make small commits

- Add the concept of storage backends, not fully fleshed out at this
point, but a good starting point
- Add a generic serializer
- Move mention_nick to the GuildMember object (I'm not sure this was a
good idea, but we'll see)
- Add a default config loader to the bot
- Fix some Python 2.x/3.x unicode stuff
- Start tracking greenlets on the Plugin level, this will help with
reloading when its fully completed
- Fix manhole locals being basically empty (sans the bot if relevant)
- Add Channel.delete_messages_bulk
- Add GuildMember.owner to check if the member owns the server
pull/5/head
Andrei 9 years ago
parent
commit
7d5370234d
  1. 3
      disco/bot/__init__.py
  2. 8
      disco/bot/backends/__init__.py
  3. 20
      disco/bot/backends/base.py
  4. 35
      disco/bot/backends/disk.py
  5. 18
      disco/bot/backends/memory.py
  6. 63
      disco/bot/bot.py
  7. 5
      disco/bot/parser.py
  8. 64
      disco/bot/plugin.py
  9. 21
      disco/bot/storage.py
  10. 8
      disco/client.py
  11. 10
      disco/types/base.py
  12. 5
      disco/types/channel.py
  13. 10
      disco/types/guild.py
  14. 4
      disco/types/user.py
  15. 42
      disco/util/config.py
  16. 32
      disco/util/serializer.py
  17. 12
      examples/basic_plugin.py

3
disco/bot/__init__.py

@ -1,4 +1,5 @@
from disco.bot.bot import Bot, BotConfig from disco.bot.bot import Bot, BotConfig
from disco.bot.plugin import Plugin from disco.bot.plugin import Plugin
from disco.util.config import Config
__all__ = ['Bot', 'BotConfig', 'Plugin'] __all__ = ['Bot', 'BotConfig', 'Plugin', 'Config']

8
disco/bot/backends/__init__.py

@ -0,0 +1,8 @@
from .memory import MemoryBackend
from .disk import DiskBackend
BACKENDS = {
'memory': MemoryBackend,
'disk': DiskBackend,
}

20
disco/bot/backends/base.py

@ -0,0 +1,20 @@
class BaseStorageBackend(object):
def base(self):
return self.storage
def __getitem__(self, key):
return self.storage[key]
def __setitem__(self, key, value):
self.storage[key] = value
def __delitem__(self, key):
del self.storage[key]
class StorageDict(dict):
def ensure(self, name):
if not dict.__contains__(self, name):
dict.__setitem__(self, name, StorageDict())
return dict.__getitem__(self, name)

35
disco/bot/backends/disk.py

@ -0,0 +1,35 @@
import os
from .base import BaseStorageBackend, StorageDict
class DiskBackend(BaseStorageBackend):
def __init__(self, config):
self.format = config.get('format', 'json')
self.path = config.get('path', 'storage') + '.' + self.format
self.storage = StorageDict()
@staticmethod
def get_format_functions(fmt):
if fmt == 'json':
from json import loads, dumps
return (loads, dumps)
elif fmt == 'yaml':
from pyyaml import load, dump
return (load, dump)
raise Exception('Unsupported format type {}'.format(fmt))
def load(self):
if not os.path.exists(self.path):
return
decode, _ = self.get_format_functions(self.format)
with open(self.path, 'r') as f:
self.storage = decode(f.read())
def dump(self):
_, encode = self.get_format_functions(self.format)
with open(self.path, 'w') as f:
f.write(encode(self.storage))

18
disco/bot/backends/memory.py

@ -0,0 +1,18 @@
from .base import BaseStorageBackend, StorageDict
class MemoryBackend(BaseStorageBackend):
def __init__(self):
self.storage = StorageDict()
def base(self):
return self.storage
def __getitem__(self, key):
return self.storage[key]
def __setitem__(self, key, value):
self.storage[key] = value
def __delitem__(self, key):
del self.storage[key]

63
disco/bot/bot.py

@ -1,4 +1,5 @@
import re import re
import os
import importlib import importlib
import inspect import inspect
@ -7,10 +8,12 @@ from holster.threadlocal import ThreadLocal
from disco.bot.plugin import Plugin from disco.bot.plugin import Plugin
from disco.bot.command import CommandEvent from disco.bot.command import CommandEvent
# from disco.bot.storage import Storage from disco.bot.storage import Storage
from disco.util.config import Config
from disco.util.serializer import Serializer
class BotConfig(object): class BotConfig(Config):
""" """
An object which is used to configure and define the runtime configuration for An object which is used to configure and define the runtime configuration for
a bot. a bot.
@ -40,9 +43,14 @@ class BotConfig(object):
message in a channel, and did not previously trigger a command. This is message in a channel, and did not previously trigger a command. This is
helpful for allowing edits to typod commands. helpful for allowing edits to typod commands.
plugin_config_provider : Optional[function] plugin_config_provider : Optional[function]
If set, this function will be called before loading a plugin, with the If set, this function will replace the default configuration loading
plugins class. Its expected to return a type of configuration object the function, which normally attempts to load a file located at config/plugin_name.fmt
plugin understands. where fmt is the plugin_config_format. The function here should return
a valid configuration object which the plugin understands.
plugin_config_format : str
The serilization format plugin configuration files are in.
plugin_config_dir : str
The directory plugin configuration is located within.
""" """
token = None token = None
@ -58,6 +66,13 @@ class BotConfig(object):
commands_allow_edit = True commands_allow_edit = True
plugin_config_provider = None plugin_config_provider = None
plugin_config_format = 'yaml'
plugin_config_dir = 'config'
storage_enabled = False
storage_backend = 'memory'
storage_autosave = True
storage_autosave_interval = 120
class Bot(object): class Bot(object):
@ -90,7 +105,9 @@ class Bot(object):
self.ctx = ThreadLocal() self.ctx = ThreadLocal()
# The storage object acts as a dynamic contextual aware store # The storage object acts as a dynamic contextual aware store
# self.storage = Storage(self.ctx) self.storage = None
if self.config.storage_enabled:
self.storage = Storage(self.ctx, self.config.from_prefix('storage'))
if self.client.config.manhole_enable: if self.client.config.manhole_enable:
self.client.manhole_locals['bot'] = self self.client.manhole_locals['bot'] = self
@ -181,8 +198,12 @@ class Bot(object):
raise StopIteration raise StopIteration
if mention_direct: if mention_direct:
if msg.guild:
member = msg.guild.get_member(self.client.state.me)
if member:
content = content.replace(member.mention, '', 1)
else:
content = content.replace(self.client.state.me.mention, '', 1) content = content.replace(self.client.state.me.mention, '', 1)
content = content.replace(self.client.state.me.mention_nick, '', 1)
elif mention_everyone: elif mention_everyone:
content = content.replace('@everyone', '', 1) content = content.replace('@everyone', '', 1)
else: else:
@ -265,8 +286,11 @@ class Bot(object):
if cls.__name__ in self.plugins: if cls.__name__ in self.plugins:
raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) raise Exception('Cannot add already added plugin: {}'.format(cls.__name__))
if not config and callable(self.config.plugin_config_provider): if not config:
if callable(self.config.plugin_config_provider):
config = self.config.plugin_config_provider(cls) config = self.config.plugin_config_provider(cls)
else:
config = self.load_plugin_config(cls)
self.plugins[cls.__name__] = cls(self, config) self.plugins[cls.__name__] = cls(self, config)
self.plugins[cls.__name__].load() self.plugins[cls.__name__].load()
@ -317,3 +341,26 @@ class Bot(object):
break break
else: else:
raise Exception('Could not find any plugins to load within module {}'.format(path)) raise Exception('Could not find any plugins to load within module {}'.format(path))
def load_plugin_config(self, cls):
name = cls.__name__.lower()
if name.startswith('plugin'):
name = name[6:]
path = os.path.join(
self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format
if not os.path.exists(path):
if hasattr(cls, 'config_cls'):
return cls.config_cls()
return
with open(path, 'r') as f:
data = Serializer.loads(self.config.plugin_config_format, f.read())
if hasattr(cls, 'config_cls'):
inst = cls.config_cls()
inst.update(data)
return inst
return data

5
disco/bot/parser.py

@ -1,4 +1,5 @@
import re import re
import six
import copy import copy
@ -7,7 +8,7 @@ PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])')
# Mapping of types # Mapping of types
TYPE_MAP = { TYPE_MAP = {
'str': lambda ctx, data: str(data), 'str': lambda ctx, data: str(data) if six.PY3 else unicode(data),
'int': lambda ctx, data: int(data), 'int': lambda ctx, data: int(data),
'float': lambda ctx, data: int(data), 'float': lambda ctx, data: int(data),
'snowflake': lambda ctx, data: int(data), 'snowflake': lambda ctx, data: int(data),
@ -160,7 +161,7 @@ class ArgumentSet(object):
try: try:
raw[idx] = self.convert(ctx, arg.types, r) raw[idx] = self.convert(ctx, arg.types, r)
except: except:
raise ArgumentError('cannot convert `{}` to `{}`'.format( raise ArgumentError(u'cannot convert `{}` to `{}`'.format(
r, ', '.join(arg.types) r, ', '.join(arg.types)
)) ))

64
disco/bot/plugin.py

@ -1,7 +1,7 @@
import inspect import inspect
import functools import functools
import gevent import gevent
import os import weakref
from holster.emitter import Priority from holster.emitter import Priority
@ -27,6 +27,16 @@ class PluginDeco(object):
return f return f
return deco return deco
@classmethod
def with_config(cls, config_cls):
"""
Sets the plugins config class to the specified config class.
"""
def deco(plugin_cls):
plugin_cls.config_cls = config_cls
return plugin_cls
return deco
@classmethod @classmethod
def listen(cls, event_name, priority=None): def listen(cls, event_name, priority=None):
""" """
@ -86,13 +96,14 @@ class PluginDeco(object):
}) })
@classmethod @classmethod
def schedule(cls, interval=60): def schedule(cls, *args, **kwargs):
""" """
Runs a function repeatedly, waiting for a specified interval Runs a function repeatedly, waiting for a specified interval
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'schedule', 'type': 'schedule',
'interval': interval, 'args': args,
'kwargs': kwargs,
}) })
@ -131,10 +142,15 @@ class Plugin(LoggingClass, PluginDeco):
self.listeners = [] self.listeners = []
self.commands = {} self.commands = {}
self.schedules = {} self.schedules = {}
self.greenlets = weakref.WeakSet()
self._pre = {'command': [], 'listener': []} self._pre = {'command': [], 'listener': []}
self._post = {'command': [], 'listener': []} self._post = {'command': [], 'listener': []}
# TODO: when handling events/commands we need to track the greenlet in
# the greenlets set so we can termiante long running commands/listeners
# on reload.
for name, member in inspect.getmembers(self, predicate=inspect.ismethod): for name, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, 'meta'): if hasattr(member, 'meta'):
for meta in member.meta: for meta in member.meta:
@ -143,11 +159,16 @@ class Plugin(LoggingClass, PluginDeco):
elif meta['type'] == 'command': elif meta['type'] == 'command':
self.register_command(member, *meta['args'], **meta['kwargs']) self.register_command(member, *meta['args'], **meta['kwargs'])
elif meta['type'] == 'schedule': elif meta['type'] == 'schedule':
self.register_schedule(member, meta['interval']) self.register_schedule(member, *meta['args'], **meta['kwargs'])
elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'):
when, typ = meta['type'].split('_', 1) when, typ = meta['type'].split('_', 1)
self.register_trigger(typ, when, member) self.register_trigger(typ, when, member)
def spawn(self, method, *args, **kwargs):
obj = gevent.spawn(method, *args, **kwargs)
self.greenlets.add(obj)
return obj
def execute(self, event): def execute(self, event):
""" """
Executes a CommandEvent this plugin owns Executes a CommandEvent this plugin owns
@ -217,7 +238,7 @@ class Plugin(LoggingClass, PluginDeco):
wrapped = functools.partial(self._dispatch, 'command', func) wrapped = functools.partial(self._dispatch, 'command', func)
self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs) self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs)
def register_schedule(self, func, interval): def register_schedule(self, func, interval, repeat=True, init=True):
""" """
Registers a function to be called repeatedly, waiting for an interval Registers a function to be called repeatedly, waiting for an interval
duration. duration.
@ -230,11 +251,16 @@ class Plugin(LoggingClass, PluginDeco):
Interval (in seconds) to repeat the function on. Interval (in seconds) to repeat the function on.
""" """
def repeat(): def repeat():
while True: if init:
func() func()
while True:
gevent.sleep(interval) gevent.sleep(interval)
func()
if not repeat:
break
self.schedules[func.__name__] = gevent.spawn(repeat) self.schedules[func.__name__] = self.spawn(repeat)
def load(self): def load(self):
""" """
@ -246,6 +272,9 @@ class Plugin(LoggingClass, PluginDeco):
""" """
Called when the plugin is unloaded Called when the plugin is unloaded
""" """
for greenlet in self.greenlets:
greenlet.kill()
for listener in self.listeners: for listener in self.listeners:
listener.remove() listener.remove()
@ -254,24 +283,3 @@ class Plugin(LoggingClass, PluginDeco):
def reload(self): def reload(self):
self.bot.reload_plugin(self.__class__) self.bot.reload_plugin(self.__class__)
@staticmethod
def load_config_from_path(cls, path, format='json'):
inst = cls()
if not os.path.exists(path):
return inst
with open(path, 'r') as f:
data = f.read()
if format == 'json':
import json
inst.__dict__.update(json.loads(data))
elif format == 'yaml':
import yaml
inst.__dict__.update(yaml.load(data))
else:
raise Exception('Unsupported config format {}'.format(format))
return inst

21
disco/bot/storage.py

@ -0,0 +1,21 @@
from .backends import BACKENDS
class Storage(object):
def __init__(self, ctx, config):
self.ctx = ctx
self.backend = BACKENDS[config.backend]
# TODO: autosave
# config.autosave config.autosave_interval
@property
def guild(self):
return self.backend.base().ensure('guilds').ensure(self.ctx['guild'].id)
@property
def channel(self):
return self.backend.base().ensure('channels').ensure(self.ctx['channel'].id)
@property
def user(self):
return self.backend.base().ensure('users').ensure(self.ctx['user'].id)

8
disco/client.py

@ -85,7 +85,13 @@ class Client(object):
self.gw = GatewayClient(self, self.config.encoding_cls) self.gw = GatewayClient(self, self.config.encoding_cls)
if self.config.manhole_enable: if self.config.manhole_enable:
self.manhole_locals = {} self.manhole_locals = {
'client': self,
'state': self.state,
'api': self.api,
'gw': self.gw
}
self.manhole = DiscoBackdoorServer(self.config.manhole_bind, self.manhole = DiscoBackdoorServer(self.config.manhole_bind,
banner='Disco Manhole', banner='Disco Manhole',
localf=lambda: self.manhole_locals) localf=lambda: self.manhole_locals)

10
disco/types/base.py

@ -111,11 +111,17 @@ def datetime(data):
def text(obj): def text(obj):
return six.text_type(obj) if obj else six.text_type() if six.PY2:
return unicode(obj)
else:
return str(obj)
def binary(obj): def binary(obj):
return six.text_type(obj) if obj else six.text_type() if six.PY2:
return unicode(obj)
else:
return bytes(obj)
def field(typ, alias=None): def field(typ, alias=None):

5
disco/types/channel.py

@ -3,6 +3,7 @@ from holster.enum import Enum
from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text
from disco.types.permissions import PermissionValue from disco.types.permissions import PermissionValue
from disco.util import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property
from disco.types.user import User from disco.types.user import User
from disco.types.permissions import Permissions, Permissible from disco.types.permissions import Permissions, Permissible
@ -241,6 +242,10 @@ class Channel(Model, Permissible):
def delete_overwrite(self, ow): def delete_overwrite(self, ow):
self.client.api.channels_permissions_delete(self.id, ow.id) self.client.api.channels_permissions_delete(self.id, ow.id)
def delete_messages_bulk(self, messages):
messages = map(to_snowflake, messages)
self.client.api.channels_messages_delete_bulk(self.id, messages)
class MessageIterator(object): class MessageIterator(object):
""" """

10
disco/types/guild.py

@ -156,6 +156,16 @@ class GuildMember(Model):
roles = self.roles + [role.id] roles = self.roles + [role.id]
self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles) self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles)
@property
def owner(self):
return self.guild.owner_id == self.id
@property
def mention(self):
if self.nick:
return '<@!{}>'.format(self.id)
return self.user.mention
@cached_property @cached_property
def guild(self): def guild(self):
return self.client.state.guilds.get(self.guild_id) return self.client.state.guilds.get(self.guild_id)

4
disco/types/user.py

@ -13,10 +13,6 @@ class User(Model):
def mention(self): def mention(self):
return '<@{}>'.format(self.id) return '<@{}>'.format(self.id)
@property
def mention_nick(self):
return '<@!{}>'.format(self.id)
def to_string(self): def to_string(self):
return '{}#{}'.format(self.username, self.discriminator) return '{}#{}'.format(self.username, self.discriminator)

42
disco/util/config.py

@ -0,0 +1,42 @@
import os
import six
from .serializer import Serializer
class Config(object):
def __init__(self, obj=None):
self.__dict__.update({
k: getattr(self, k) for k in dir(self.__class__)
})
if obj:
self.__dict__.update(obj)
@classmethod
def from_file(cls, path):
inst = cls()
with open(path, 'r') as f:
data = f.read()
_, ext = os.path.splitext(path)
Serializer.check_format(ext)
inst.__dict__.update(Serializer.load(ext, data))
return inst
def from_prefix(self, prefix):
prefix = prefix + '_'
obj = {}
for k, v in six.iteritems(self.__dict__):
if k.startswith(prefix):
obj[k[len(prefix):]] = v
return obj
def update(self, other):
if isinstance(other, Config):
other = other.__dict__
self.__dict__.update(other)

32
disco/util/serializer.py

@ -0,0 +1,32 @@
class Serializer(object):
FORMATS = {
'json',
'yaml'
}
@classmethod
def check_format(cls, fmt):
if fmt not in cls.FORMATS:
raise Exception('Unsupported serilization format: {}'.format(fmt))
@staticmethod
def json():
from json import loads, dumps
return (loads, dumps)
@staticmethod
def yaml():
from yaml import load, dump
return (load, dump)
@classmethod
def loads(cls, fmt, raw):
loads, _ = getattr(cls, fmt)()
return loads(raw)
@classmethod
def dumps(cls, fmt, raw):
_, dumps = getattr(cls, fmt)()
return dumps(raw)

12
examples/basic_plugin.py

@ -95,15 +95,15 @@ class BasicPlugin(Plugin):
json.dumps(perms.to_dict(), sort_keys=True, indent=2, separators=(',', ': ')) json.dumps(perms.to_dict(), sort_keys=True, indent=2, separators=(',', ': '))
)) ))
"""
@Plugin.command('tag', '<name:str> [value:str]') @Plugin.command('tag', '<name:str> [value:str]')
def on_tag(self, event, name, value=None): def on_tag(self, event, name, value=None):
tags = self.storage.guild.ensure('tags')
if value: if value:
self.storage.guild['tags'][name] = value tags[name] = value
event.msg.reply(':ok_hand:') event.msg.reply(':ok_hand:')
else: else:
if name in self.storage.guild['tags']: if name in tags:
return event.msg.reply(self.storage.guild['tags'][name]) return event.msg.reply(tags[name])
else: else:
event.msg.reply('Unknown tag `{}`'.format(name)) return event.msg.reply('Unknown tag: `{}`'.format(name))
"""

Loading…
Cancel
Save