Browse Source

Fixes, Cleanup, Plugin Storage

The biggest part of this commit is a plugin storage subsystem, which at
this point I'm fairly happy with. I've iterated on this a couple times,
and the final result has a very clean/simple interface, is easy to
extend to different data stores, and has a minimal number of
grokable edge cases.

- Storage subsystem
- Fix command group abbreviations
- Fix reconnecting in the GatewaySocket
- Add pickle support to serializer
pull/5/head
Andrei 9 years ago
parent
commit
41f7126a1d
  1. 1
      .gitignore
  2. 8
      disco/bot/backends/__init__.py
  3. 20
      disco/bot/backends/base.py
  4. 35
      disco/bot/backends/disk.py
  5. 7
      disco/bot/backends/memory.py
  6. 8
      disco/bot/bot.py
  7. 6
      disco/bot/command.py
  8. 6
      disco/bot/plugin.py
  9. 15
      disco/bot/providers/__init__.py
  10. 136
      disco/bot/providers/base.py
  11. 53
      disco/bot/providers/disk.py
  12. 5
      disco/bot/providers/memory.py
  13. 50
      disco/bot/providers/rocksdb.py
  14. 19
      disco/bot/storage.py
  15. 5
      disco/gateway/client.py
  16. 5
      disco/state.py
  17. 8
      disco/util/serializer.py

1
.gitignore

@ -2,3 +2,4 @@ build/
dist/ dist/
disco*.egg-info/ disco*.egg-info/
docs/_build docs/_build
storage.db

8
disco/bot/backends/__init__.py

@ -1,8 +0,0 @@
from .memory import MemoryBackend
from .disk import DiskBackend
BACKENDS = {
'memory': MemoryBackend,
'disk': DiskBackend,
}

20
disco/bot/backends/base.py

@ -1,20 +0,0 @@
class BaseStorageBackend(object):
def base(self):
return self.storage
def __getitem__(self, key):
return self.storage[key]
def __setitem__(self, key, value):
self.storage[key] = value
def __delitem__(self, key):
del self.storage[key]
class StorageDict(dict):
def ensure(self, name):
if not dict.__contains__(self, name):
dict.__setitem__(self, name, StorageDict())
return dict.__getitem__(self, name)

35
disco/bot/backends/disk.py

@ -1,35 +0,0 @@
import os
from .base import BaseStorageBackend, StorageDict
class DiskBackend(BaseStorageBackend):
def __init__(self, config):
self.format = config.get('format', 'json')
self.path = config.get('path', 'storage') + '.' + self.format
self.storage = StorageDict()
@staticmethod
def get_format_functions(fmt):
if fmt == 'json':
from json import loads, dumps
return (loads, dumps)
elif fmt == 'yaml':
from pyyaml import load, dump
return (load, dump)
raise Exception('Unsupported format type {}'.format(fmt))
def load(self):
if not os.path.exists(self.path):
return
decode, _ = self.get_format_functions(self.format)
with open(self.path, 'r') as f:
self.storage = decode(f.read())
def dump(self):
_, encode = self.get_format_functions(self.format)
with open(self.path, 'w') as f:
f.write(encode(self.storage))

7
disco/bot/backends/memory.py

@ -1,7 +0,0 @@
from .base import BaseStorageBackend, StorageDict
class MemoryBackend(BaseStorageBackend):
def __init__(self, config):
self.storage = StorageDict()

8
disco/bot/bot.py

@ -84,9 +84,8 @@ class BotConfig(Config):
plugin_config_dir = 'config' plugin_config_dir = 'config'
storage_enabled = False storage_enabled = False
storage_backend = 'memory' storage_provider = 'memory'
storage_autosave = True storage_config = {}
storage_autosave_interval = 120
class Bot(object): class Bot(object):
@ -184,8 +183,9 @@ class Bot(object):
""" """
Called when a plugin is loaded/unloaded to recompute internal state. Called when a plugin is loaded/unloaded to recompute internal state.
""" """
self.compute_group_abbrev()
if self.config.commands_group_abbrev: if self.config.commands_group_abbrev:
self.compute_group_abbrev()
self.compute_command_matches_re() self.compute_command_matches_re()
def compute_group_abbrev(self): def compute_group_abbrev(self):

6
disco/bot/command.py

@ -165,10 +165,10 @@ class Command(object):
else: else:
group = '' group = ''
if self.group: if self.group:
if self.group in self.plugin.bot.group_abbrev.get(self.group): if self.group in self.plugin.bot.group_abbrev:
group = '{}(?:\w+)? '.format(self.group) group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group))
else: else:
group = self.group group = self.group + ' '
return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX) return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX)
def execute(self, event): def execute(self, event):

6
disco/bot/plugin.py

@ -17,6 +17,7 @@ class PluginDeco(object):
""" """
Prio = Priority Prio = Priority
# TODO: dont smash class methods
@staticmethod @staticmethod
def add_meta_deco(meta): def add_meta_deco(meta):
def deco(f): def deco(f):
@ -152,6 +153,10 @@ class Plugin(LoggingClass, PluginDeco):
self.storage = bot.storage self.storage = bot.storage
self.config = config self.config = config
@property
def name(self):
return self.__class__.__name__
def bind_all(self): def bind_all(self):
self.listeners = [] self.listeners = []
self.commands = {} self.commands = {}
@ -188,6 +193,7 @@ class Plugin(LoggingClass, PluginDeco):
""" """
Executes a CommandEvent this plugin owns Executes a CommandEvent this plugin owns
""" """
self.ctx['plugin'] = self
self.ctx['guild'] = event.guild self.ctx['guild'] = event.guild
self.ctx['channel'] = event.channel self.ctx['channel'] = event.channel
self.ctx['user'] = event.author self.ctx['user'] = event.author

15
disco/bot/providers/__init__.py

@ -0,0 +1,15 @@
import inspect
import importlib
from .base import BaseProvider
def load_provider(name):
try:
mod = importlib.import_module('disco.bot.providers.' + name)
except ImportError:
mod = importlib.import_module(name)
for entry in filter(inspect.isclass, map(lambda i: getattr(mod, i), dir(mod))):
if issubclass(entry, BaseProvider) and entry != BaseProvider:
return entry

136
disco/bot/providers/base.py

@ -0,0 +1,136 @@
import six
import pickle
from six.moves import map
from UserDict import UserDict
ROOT_SENTINEL = u'\u200B'
SEP_SENTINEL = u'\u200D'
OBJ_SENTINEL = u'\u200C'
CAST_SENTINEL = u'\u24EA'
def join_key(*args):
nargs = []
for arg in args:
if not isinstance(arg, six.string_types):
arg = CAST_SENTINEL + pickle.dumps(arg)
nargs.append(arg)
return SEP_SENTINEL.join(nargs)
def true_key(key):
key = key.rsplit(SEP_SENTINEL, 1)[-1]
if key.startswith(CAST_SENTINEL):
return pickle.loads(key)
return key
class BaseProvider(object):
def __init__(self, config):
self.config = config
self.data = {}
def exists(self, key):
return key in self.data
def keys(self, other):
count = other.count(SEP_SENTINEL) + 1
for key in self.data.keys():
if key.startswith(other) and key.count(SEP_SENTINEL) == count:
yield key
def get_many(self, keys):
for key in keys:
yield key, self.get(key)
def get(self, key):
return self.data[key]
def set(self, key, value):
self.data[key] = value
def delete(self, key):
del self.data[key]
def load(self):
pass
def save(self):
pass
def root(self):
return StorageDict(self)
class StorageDict(UserDict):
def __init__(self, parent_or_provider, key=None):
if isinstance(parent_or_provider, BaseProvider):
self.provider = parent_or_provider
self.parent = None
else:
self.parent = parent_or_provider
self.provider = self.parent.provider
self._key = key or ROOT_SENTINEL
def keys(self):
return map(true_key, self.provider.keys(self.key))
def values(self):
for key in self.keys():
yield self.provider.get(key)
def items(self):
for key in self.keys():
yield (true_key(key), self.provider.get(key))
def ensure(self, key, typ=dict):
if key not in self:
self[key] = typ()
return self[key]
def update(self, obj):
for k, v in six.iteritems(obj):
self[k] = v
@property
def data(self):
obj = {}
for raw, value in self.provider.get_many(self.provider.keys(self.key)):
key = true_key(raw)
if value == OBJ_SENTINEL:
value = self.__class__(self, key=key).data
obj[key] = value
return obj
@property
def key(self):
if self.parent is not None:
return join_key(self.parent.key, self._key)
return self._key
def __setitem__(self, key, value):
if isinstance(value, dict):
obj = self.__class__(self, key)
obj.update(value)
value = OBJ_SENTINEL
self.provider.set(join_key(self.key, key), value)
def __getitem__(self, key):
res = self.provider.get(join_key(self.key, key))
if res == OBJ_SENTINEL:
return self.__class__(self, key)
return res
def __delitem__(self, key):
return self.provider.delete(join_key(self.key, key))
def __contains__(self, key):
return self.provider.exists(join_key(self.key, key))

53
disco/bot/providers/disk.py

@ -0,0 +1,53 @@
import os
import gevent
from disco.util.serializer import Serializer
from .base import BaseProvider
class DiskProvider(BaseProvider):
def __init__(self, config):
super(DiskProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.path = config.get('path', 'storage') + '.' + self.format
self.fsync = config.get('fsync', False)
self.fsync_changes = config.get('fsync_changes', 1)
self.change_count = 0
def autosave_loop(self, interval):
while True:
gevent.sleep(interval)
self.save()
def _on_change(self):
if self.fsync:
self.change_count += 1
if self.change_count >= self.fsync_changes:
self.save()
self.change_count = 0
def load(self):
if not os.path.exists(self.path):
return
if self.config.get('autosave', True):
self.autosave_task = gevent.spawn(
self.autosave_loop,
self.config.get('autosave_interval', 120))
with open(self.path, 'r') as f:
self.data = Serializer.loads(self.format, f.read())
def save(self):
with open(self.path, 'w') as f:
f.write(Serializer.dumps(self.format, self.data))
def set(self, key, value):
super(DiskProvider, self).set(key, value)
self._on_change()
def delete(self, key):
super(DiskProvider, self).delete(key)
self._on_change()

5
disco/bot/providers/memory.py

@ -0,0 +1,5 @@
from .base import BaseProvider
class MemoryProvider(BaseProvider):
pass

50
disco/bot/providers/rocksdb.py

@ -0,0 +1,50 @@
from __future__ import absolute_import
import six
import rocksdb
from itertools import izip
from six.moves import map
from disco.util.serializer import Serializer
from .base import BaseProvider, SEP_SENTINEL
class RocksDBProvider(BaseProvider):
def __init__(self, config):
self.config = config
self.format = config.get('format', 'pickle')
self.path = config.get('path', 'storage.db')
def k(self, k):
return bytes(k) if six.PY3 else str(k.encode('utf-8'))
def load(self):
self.db = rocksdb.DB(self.path, rocksdb.Options(create_if_missing=True))
def exists(self, key):
return self.db.get(self.k(key)) is not None
# TODO prefix extractor
def keys(self, other):
count = other.count(SEP_SENTINEL) + 1
it = self.db.iterkeys()
it.seek_to_first()
for key in it:
key = key.decode('utf-8')
if key.startswith(other) and key.count(SEP_SENTINEL) == count:
yield key
def get_many(self, keys):
for key, value in izip(keys, self.db.multi_get(list(map(self.k, keys)))):
yield (key, Serializer.loads(self.format, value.decode('utf-8')))
def get(self, key):
return Serializer.loads(self.format, self.db.get(self.k(key)).decode('utf-8'))
def set(self, key, value):
self.db.put(self.k(key), Serializer.dumps(self.format, value))
def delete(self, key):
self.db.delete(self.k(key))

19
disco/bot/storage.py

@ -1,21 +1,26 @@
from .backends import BACKENDS from .providers import load_provider
class Storage(object): class Storage(object):
def __init__(self, ctx, config): def __init__(self, ctx, config):
self.ctx = ctx self.ctx = ctx
self.backend = BACKENDS[config.backend] self.config = config
# TODO: autosave self.provider = load_provider(config.provider)(config.config)
# config.autosave config.autosave_interval self.provider.load()
self.root = self.provider.root()
@property
def plugin(self):
return self.root.ensure('plugins').ensure(self.ctx['plugin'].name)
@property @property
def guild(self): def guild(self):
return self.backend.base().ensure('guilds').ensure(self.ctx['guild'].id) return self.plugin.ensure('guilds').ensure(self.ctx['guild'].id)
@property @property
def channel(self): def channel(self):
return self.backend.base().ensure('channels').ensure(self.ctx['channel'].id) return self.plugin.ensure('channels').ensure(self.ctx['channel'].id)
@property @property
def user(self): def user(self):
return self.backend.base().ensure('users').ensure(self.ctx['user'].id) return self.plugin.ensure('users').ensure(self.ctx['user'].id)

5
disco/gateway/client.py

@ -36,6 +36,7 @@ class GatewayClient(LoggingClass):
# Websocket connection # Websocket connection
self.ws = None self.ws = None
self.ws_event = gevent.event.Event()
# State # State
self.seq = 0 self.seq = 0
@ -125,6 +126,7 @@ class GatewayClient(LoggingClass):
def on_error(self, error): def on_error(self, error):
if isinstance(error, KeyboardInterrupt): if isinstance(error, KeyboardInterrupt):
self.shutting_down = True self.shutting_down = True
self.ws_event.set()
raise Exception('WS recieved error: %s', error) raise Exception('WS recieved error: %s', error)
def on_open(self): def on_open(self):
@ -176,4 +178,5 @@ class GatewayClient(LoggingClass):
self.connect_and_run() self.connect_and_run()
def run(self): def run(self):
self.connect_and_run() gevent.spawn(self.connect_and_run)
self.ws_event.wait()

5
disco/state.py

@ -98,7 +98,7 @@ class State(object):
# If message tracking is enabled, listen to those events # If message tracking is enabled, listen to those events
if self.config.track_messages: if self.config.track_messages:
self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size))
self.EVENTS += ['MessageDelete'] self.EVENTS += ['MessageDelete', 'MessageDeleteBulk']
# The bound listener objects # The bound listener objects
self.listeners = [] self.listeners = []
@ -152,7 +152,8 @@ class State(object):
if event.channel_id not in self.messages: if event.channel_id not in self.messages:
return return
for sm in self.messages[event.channel_id]: # TODO: performance
for sm in list(self.messages[event.channel_id]):
if sm.id in event.ids: if sm.id in event.ids:
self.messages[event.channel_id].remove(sm) self.messages[event.channel_id].remove(sm)

8
disco/util/serializer.py

@ -3,7 +3,8 @@
class Serializer(object): class Serializer(object):
FORMATS = { FORMATS = {
'json', 'json',
'yaml' 'yaml',
'pickle',
} }
@classmethod @classmethod
@ -21,6 +22,11 @@ class Serializer(object):
from yaml import load, dump from yaml import load, dump
return (load, dump) return (load, dump)
@staticmethod
def pickle():
from pickle import loads, dumps
return (loads, dumps)
@classmethod @classmethod
def loads(cls, fmt, raw): def loads(cls, fmt, raw):
loads, _ = getattr(cls, fmt)() loads, _ = getattr(cls, fmt)()

Loading…
Cancel
Save