diff --git a/.gitignore b/.gitignore index 8d18da1..54cb974 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build/ dist/ disco*.egg-info/ docs/_build +storage.db diff --git a/disco/bot/backends/__init__.py b/disco/bot/backends/__init__.py deleted file mode 100644 index b8df75b..0000000 --- a/disco/bot/backends/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .memory import MemoryBackend -from .disk import DiskBackend - - -BACKENDS = { - 'memory': MemoryBackend, - 'disk': DiskBackend, -} diff --git a/disco/bot/backends/base.py b/disco/bot/backends/base.py deleted file mode 100644 index fe005f8..0000000 --- a/disco/bot/backends/base.py +++ /dev/null @@ -1,20 +0,0 @@ - -class BaseStorageBackend(object): - def base(self): - return self.storage - - def __getitem__(self, key): - return self.storage[key] - - def __setitem__(self, key, value): - self.storage[key] = value - - def __delitem__(self, key): - del self.storage[key] - - -class StorageDict(dict): - def ensure(self, name): - if not dict.__contains__(self, name): - dict.__setitem__(self, name, StorageDict()) - return dict.__getitem__(self, name) diff --git a/disco/bot/backends/disk.py b/disco/bot/backends/disk.py deleted file mode 100644 index 7de40b7..0000000 --- a/disco/bot/backends/disk.py +++ /dev/null @@ -1,35 +0,0 @@ -import os - -from .base import BaseStorageBackend, StorageDict - - -class DiskBackend(BaseStorageBackend): - def __init__(self, config): - self.format = config.get('format', 'json') - self.path = config.get('path', 'storage') + '.' + self.format - self.storage = StorageDict() - - @staticmethod - def get_format_functions(fmt): - if fmt == 'json': - from json import loads, dumps - return (loads, dumps) - elif fmt == 'yaml': - from pyyaml import load, dump - return (load, dump) - raise Exception('Unsupported format type {}'.format(fmt)) - - def load(self): - if not os.path.exists(self.path): - return - - decode, _ = self.get_format_functions(self.format) - - with open(self.path, 'r') as f: - self.storage = decode(f.read()) - - def dump(self): - _, encode = self.get_format_functions(self.format) - - with open(self.path, 'w') as f: - f.write(encode(self.storage)) diff --git a/disco/bot/backends/memory.py b/disco/bot/backends/memory.py deleted file mode 100644 index 8e26ca2..0000000 --- a/disco/bot/backends/memory.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base import BaseStorageBackend, StorageDict - - -class MemoryBackend(BaseStorageBackend): - def __init__(self, config): - self.storage = StorageDict() - diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 215bdb6..1d14a91 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -84,9 +84,8 @@ class BotConfig(Config): plugin_config_dir = 'config' storage_enabled = False - storage_backend = 'memory' - storage_autosave = True - storage_autosave_interval = 120 + storage_provider = 'memory' + storage_config = {} class Bot(object): @@ -184,9 +183,10 @@ class Bot(object): """ Called when a plugin is loaded/unloaded to recompute internal state. """ - self.compute_group_abbrev() if self.config.commands_group_abbrev: - self.compute_command_matches_re() + self.compute_group_abbrev() + + self.compute_command_matches_re() def compute_group_abbrev(self): """ diff --git a/disco/bot/command.py b/disco/bot/command.py index d851c8b..2546c9d 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -165,10 +165,10 @@ class Command(object): else: group = '' if self.group: - if self.group in self.plugin.bot.group_abbrev.get(self.group): - group = '{}(?:\w+)? '.format(self.group) + if self.group in self.plugin.bot.group_abbrev: + group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group)) else: - group = self.group + group = self.group + ' ' return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX) def execute(self, event): diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index a8585f7..89489b9 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -17,6 +17,7 @@ class PluginDeco(object): """ Prio = Priority + # TODO: dont smash class methods @staticmethod def add_meta_deco(meta): def deco(f): @@ -152,6 +153,10 @@ class Plugin(LoggingClass, PluginDeco): self.storage = bot.storage self.config = config + @property + def name(self): + return self.__class__.__name__ + def bind_all(self): self.listeners = [] self.commands = {} @@ -188,6 +193,7 @@ class Plugin(LoggingClass, PluginDeco): """ Executes a CommandEvent this plugin owns """ + self.ctx['plugin'] = self self.ctx['guild'] = event.guild self.ctx['channel'] = event.channel self.ctx['user'] = event.author diff --git a/disco/bot/providers/__init__.py b/disco/bot/providers/__init__.py new file mode 100644 index 0000000..3ec985f --- /dev/null +++ b/disco/bot/providers/__init__.py @@ -0,0 +1,15 @@ +import inspect +import importlib + +from .base import BaseProvider + + +def load_provider(name): + try: + mod = importlib.import_module('disco.bot.providers.' + name) + except ImportError: + mod = importlib.import_module(name) + + for entry in filter(inspect.isclass, map(lambda i: getattr(mod, i), dir(mod))): + if issubclass(entry, BaseProvider) and entry != BaseProvider: + return entry diff --git a/disco/bot/providers/base.py b/disco/bot/providers/base.py new file mode 100644 index 0000000..033b020 --- /dev/null +++ b/disco/bot/providers/base.py @@ -0,0 +1,136 @@ +import six +import pickle + +from six.moves import map + +from UserDict import UserDict + + +ROOT_SENTINEL = u'\u200B' +SEP_SENTINEL = u'\u200D' +OBJ_SENTINEL = u'\u200C' +CAST_SENTINEL = u'\u24EA' + + +def join_key(*args): + nargs = [] + for arg in args: + if not isinstance(arg, six.string_types): + arg = CAST_SENTINEL + pickle.dumps(arg) + nargs.append(arg) + return SEP_SENTINEL.join(nargs) + + +def true_key(key): + key = key.rsplit(SEP_SENTINEL, 1)[-1] + if key.startswith(CAST_SENTINEL): + return pickle.loads(key) + return key + + +class BaseProvider(object): + def __init__(self, config): + self.config = config + self.data = {} + + def exists(self, key): + return key in self.data + + def keys(self, other): + count = other.count(SEP_SENTINEL) + 1 + for key in self.data.keys(): + if key.startswith(other) and key.count(SEP_SENTINEL) == count: + yield key + + def get_many(self, keys): + for key in keys: + yield key, self.get(key) + + def get(self, key): + return self.data[key] + + def set(self, key, value): + self.data[key] = value + + def delete(self, key): + del self.data[key] + + def load(self): + pass + + def save(self): + pass + + def root(self): + return StorageDict(self) + + +class StorageDict(UserDict): + def __init__(self, parent_or_provider, key=None): + if isinstance(parent_or_provider, BaseProvider): + self.provider = parent_or_provider + self.parent = None + else: + self.parent = parent_or_provider + self.provider = self.parent.provider + self._key = key or ROOT_SENTINEL + + def keys(self): + return map(true_key, self.provider.keys(self.key)) + + def values(self): + for key in self.keys(): + yield self.provider.get(key) + + def items(self): + for key in self.keys(): + yield (true_key(key), self.provider.get(key)) + + def ensure(self, key, typ=dict): + if key not in self: + self[key] = typ() + return self[key] + + def update(self, obj): + for k, v in six.iteritems(obj): + self[k] = v + + @property + def data(self): + obj = {} + + for raw, value in self.provider.get_many(self.provider.keys(self.key)): + key = true_key(raw) + + if value == OBJ_SENTINEL: + value = self.__class__(self, key=key).data + obj[key] = value + return obj + + @property + def key(self): + if self.parent is not None: + return join_key(self.parent.key, self._key) + return self._key + + def __setitem__(self, key, value): + if isinstance(value, dict): + obj = self.__class__(self, key) + obj.update(value) + value = OBJ_SENTINEL + + self.provider.set(join_key(self.key, key), value) + + def __getitem__(self, key): + res = self.provider.get(join_key(self.key, key)) + + if res == OBJ_SENTINEL: + return self.__class__(self, key) + + return res + + def __delitem__(self, key): + return self.provider.delete(join_key(self.key, key)) + + def __contains__(self, key): + return self.provider.exists(join_key(self.key, key)) diff --git a/disco/bot/providers/disk.py b/disco/bot/providers/disk.py new file mode 100644 index 0000000..5cf1ca3 --- /dev/null +++ b/disco/bot/providers/disk.py @@ -0,0 +1,53 @@ +import os +import gevent + +from disco.util.serializer import Serializer +from .base import BaseProvider + + +class DiskProvider(BaseProvider): + def __init__(self, config): + super(DiskProvider, self).__init__(config) + self.format = config.get('format', 'pickle') + self.path = config.get('path', 'storage') + '.' + self.format + self.fsync = config.get('fsync', False) + self.fsync_changes = config.get('fsync_changes', 1) + + self.change_count = 0 + + def autosave_loop(self, interval): + while True: + gevent.sleep(interval) + self.save() + + def _on_change(self): + if self.fsync: + self.change_count += 1 + + if self.change_count >= self.fsync_changes: + self.save() + self.change_count = 0 + + def load(self): + if not os.path.exists(self.path): + return + + if self.config.get('autosave', True): + self.autosave_task = gevent.spawn( + self.autosave_loop, + self.config.get('autosave_interval', 120)) + + with open(self.path, 'r') as f: + self.data = Serializer.loads(self.format, f.read()) + + def save(self): + with open(self.path, 'w') as f: + f.write(Serializer.dumps(self.format, self.data)) + + def set(self, key, value): + super(DiskProvider, self).set(key, value) + self._on_change() + + def delete(self, key): + super(DiskProvider, self).delete(key) + self._on_change() diff --git a/disco/bot/providers/memory.py b/disco/bot/providers/memory.py new file mode 100644 index 0000000..17ad47b --- /dev/null +++ b/disco/bot/providers/memory.py @@ -0,0 +1,5 @@ +from .base import BaseProvider + + +class MemoryProvider(BaseProvider): + pass diff --git a/disco/bot/providers/rocksdb.py b/disco/bot/providers/rocksdb.py new file mode 100644 index 0000000..0062d79 --- /dev/null +++ b/disco/bot/providers/rocksdb.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import + +import six +import rocksdb + +from itertools import izip +from six.moves import map + +from disco.util.serializer import Serializer +from .base import BaseProvider, SEP_SENTINEL + + +class RocksDBProvider(BaseProvider): + def __init__(self, config): + self.config = config + self.format = config.get('format', 'pickle') + self.path = config.get('path', 'storage.db') + + def k(self, k): + return bytes(k) if six.PY3 else str(k.encode('utf-8')) + + def load(self): + self.db = rocksdb.DB(self.path, rocksdb.Options(create_if_missing=True)) + + def exists(self, key): + return self.db.get(self.k(key)) is not None + + # TODO prefix extractor + def keys(self, other): + count = other.count(SEP_SENTINEL) + 1 + it = self.db.iterkeys() + it.seek_to_first() + + for key in it: + key = key.decode('utf-8') + if key.startswith(other) and key.count(SEP_SENTINEL) == count: + yield key + + def get_many(self, keys): + for key, value in izip(keys, self.db.multi_get(list(map(self.k, keys)))): + yield (key, Serializer.loads(self.format, value.decode('utf-8'))) + + def get(self, key): + return Serializer.loads(self.format, self.db.get(self.k(key)).decode('utf-8')) + + def set(self, key, value): + self.db.put(self.k(key), Serializer.dumps(self.format, value)) + + def delete(self, key): + self.db.delete(self.k(key)) diff --git a/disco/bot/storage.py b/disco/bot/storage.py index 45fa1f6..812d79c 100644 --- a/disco/bot/storage.py +++ b/disco/bot/storage.py @@ -1,21 +1,26 @@ -from .backends import BACKENDS +from .providers import load_provider class Storage(object): def __init__(self, ctx, config): self.ctx = ctx - self.backend = BACKENDS[config.backend] - # TODO: autosave - # config.autosave config.autosave_interval + self.config = config + self.provider = load_provider(config.provider)(config.config) + self.provider.load() + self.root = self.provider.root() + + @property + def plugin(self): + return self.root.ensure('plugins').ensure(self.ctx['plugin'].name) @property def guild(self): - return self.backend.base().ensure('guilds').ensure(self.ctx['guild'].id) + return self.plugin.ensure('guilds').ensure(self.ctx['guild'].id) @property def channel(self): - return self.backend.base().ensure('channels').ensure(self.ctx['channel'].id) + return self.plugin.ensure('channels').ensure(self.ctx['channel'].id) @property def user(self): - return self.backend.base().ensure('users').ensure(self.ctx['user'].id) + return self.plugin.ensure('users').ensure(self.ctx['user'].id) diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 6c88d9f..ac7ea79 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -36,6 +36,7 @@ class GatewayClient(LoggingClass): # Websocket connection self.ws = None + self.ws_event = gevent.event.Event() # State self.seq = 0 @@ -125,6 +126,7 @@ class GatewayClient(LoggingClass): def on_error(self, error): if isinstance(error, KeyboardInterrupt): self.shutting_down = True + self.ws_event.set() raise Exception('WS recieved error: %s', error) def on_open(self): @@ -176,4 +178,5 @@ class GatewayClient(LoggingClass): self.connect_and_run() def run(self): - self.connect_and_run() + gevent.spawn(self.connect_and_run) + self.ws_event.wait() diff --git a/disco/state.py b/disco/state.py index efa883b..abd9fe2 100644 --- a/disco/state.py +++ b/disco/state.py @@ -98,7 +98,7 @@ class State(object): # If message tracking is enabled, listen to those events if self.config.track_messages: self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) - self.EVENTS += ['MessageDelete'] + self.EVENTS += ['MessageDelete', 'MessageDeleteBulk'] # The bound listener objects self.listeners = [] @@ -152,7 +152,8 @@ class State(object): if event.channel_id not in self.messages: return - for sm in self.messages[event.channel_id]: + # TODO: performance + for sm in list(self.messages[event.channel_id]): if sm.id in event.ids: self.messages[event.channel_id].remove(sm) diff --git a/disco/util/serializer.py b/disco/util/serializer.py index 565d513..74fe766 100644 --- a/disco/util/serializer.py +++ b/disco/util/serializer.py @@ -3,7 +3,8 @@ class Serializer(object): FORMATS = { 'json', - 'yaml' + 'yaml', + 'pickle', } @classmethod @@ -21,6 +22,11 @@ class Serializer(object): from yaml import load, dump return (load, dump) + @staticmethod + def pickle(): + from pickle import loads, dumps + return (loads, dumps) + @classmethod def loads(cls, fmt, raw): loads, _ = getattr(cls, fmt)()