Browse Source

Refactor storage module, greatly simplify

feature/storage
Andrei 8 years ago
parent
commit
a6deea55ae
  1. 1
      .gitignore
  2. 9
      disco/api/client.py
  3. 7
      disco/bot/bot.py
  4. 15
      disco/bot/providers/__init__.py
  5. 134
      disco/bot/providers/base.py
  6. 54
      disco/bot/providers/disk.py
  7. 5
      disco/bot/providers/memory.py
  8. 48
      disco/bot/providers/redis.py
  9. 52
      disco/bot/providers/rocksdb.py
  10. 97
      disco/bot/storage.py
  11. 4
      disco/types/channel.py
  12. 31
      disco/util/sanitize.py
  13. 33
      examples/storage.py

1
.gitignore

@ -3,4 +3,5 @@ dist/
disco*.egg-info/
docs/_build
storage.db
storage.json
*.dca

9
disco/api/client.py

@ -3,6 +3,7 @@ import json
from disco.api.http import Routes, HTTPClient
from disco.util.logging import LoggingClass
from disco.util.sanitize import S
from disco.types.user import User
from disco.types.message import Message
@ -88,13 +89,15 @@ class APIClient(LoggingClass):
r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message))
return Message.create(self.client, r.json())
def channels_messages_create(self, channel, content=None, nonce=None, tts=False, attachment=None, embed=None):
def channels_messages_create(self, channel, content=None, nonce=None, tts=False, attachment=None, embed=None, sanitize=False):
payload = {
'nonce': nonce,
'tts': tts,
}
if content:
if sanitize:
content = S(content)
payload['content'] = content
if embed:
@ -109,10 +112,12 @@ class APIClient(LoggingClass):
return Message.create(self.client, r.json())
def channels_messages_modify(self, channel, message, content=None, embed=None):
def channels_messages_modify(self, channel, message, content=None, embed=None, sanitize=False):
payload = {}
if content:
if sanitize:
content = S(content)
payload['content'] = content
if embed:

7
disco/bot/bot.py

@ -81,12 +81,13 @@ class BotConfig(Config):
commands_group_abbrev = True
plugin_config_provider = None
plugin_config_format = 'yaml'
plugin_config_format = 'json'
plugin_config_dir = 'config'
storage_enabled = True
storage_provider = 'memory'
storage_config = {}
storage_fsync = True
storage_serializer = 'json'
storage_path = 'storage.json'
class Bot(LoggingClass):

15
disco/bot/providers/__init__.py

@ -1,15 +0,0 @@
import inspect
import importlib
from .base import BaseProvider
def load_provider(name):
    """
    Locate a storage provider class by module name.

    The name is first tried as a submodule of ``disco.bot.providers``; if that
    import fails with ``ImportError`` it is imported as an absolute module path
    instead. The first class found in the module (in ``dir()`` order) that
    subclasses ``BaseProvider`` without being ``BaseProvider`` itself is
    returned. Returns ``None`` when the module contains no such class.
    """
    try:
        module = importlib.import_module('disco.bot.providers.' + name)
    except ImportError:
        module = importlib.import_module(name)

    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        if not inspect.isclass(candidate):
            continue
        if issubclass(candidate, BaseProvider) and candidate != BaseProvider:
            return candidate
    return None

134
disco/bot/providers/base.py

@ -1,134 +0,0 @@
import six
import pickle
from six.moves import map, UserDict
# Key used for the root level of the storage tree
ROOT_SENTINEL = u'\u200B'
# Separator placed between the path components of a flattened key
SEP_SENTINEL = u'\u200D'
# Value stored in place of a nested object (its children live under their own keys)
OBJ_SENTINEL = u'\u200C'
# Prefix marking a path component that is a pickled, non-string key
CAST_SENTINEL = u'\u24EA'

# Text types accepted verbatim as key components. This is the same set as
# six.string_types on both Python 2 (str, unicode) and Python 3 (str).
_STRING_TYPES = (str, type(u''))


def join_key(*args):
    """
    Flatten path components into a single provider key.

    String components are used as-is. Any other component is pickled and
    stored as text prefixed with ``CAST_SENTINEL`` so that ``true_key`` can
    restore the original object.

    :returns: the components joined with ``SEP_SENTINEL``.
    """
    parts = []
    for arg in args:
        if not isinstance(arg, _STRING_TYPES):
            # Protocol 0 output decoded as latin-1 is a lossless text form of
            # the pickle: every byte maps to one code point below U+0100, so
            # it can never contain one of the sentinels, and it can be joined
            # with text on Python 3 (bytes + str would raise TypeError).
            arg = CAST_SENTINEL + pickle.dumps(arg, 0).decode('latin-1')
        parts.append(arg)
    return SEP_SENTINEL.join(parts)


def true_key(key):
    """
    Return the last path component of a flattened key in its original form.

    Components written by ``join_key`` with a ``CAST_SENTINEL`` prefix are
    unpickled back into the original object; plain strings are returned
    unchanged.
    """
    key = key.rsplit(SEP_SENTINEL, 1)[-1]
    if key.startswith(CAST_SENTINEL):
        # The sentinel is not part of the pickle and must be stripped first.
        # NOTE(review): this unpickles data read back from the storage
        # backend; it is only safe while that backend is trusted.
        return pickle.loads(key[len(CAST_SENTINEL):].encode('latin-1'))
    return key
class BaseProvider(object):
    """
    Base storage provider keeping every entry in an in-process dict.

    Keys are flat strings whose path components are separated by
    ``SEP_SENTINEL``. Subclasses override the accessors to target another
    backend and ``load``/``save`` to persist data.
    """

    def __init__(self, config):
        self.config = config
        self.data = {}

    def exists(self, key):
        """Return whether ``key`` is present."""
        return key in self.data

    def keys(self, other):
        """
        Yield the keys that are direct children of the key ``other``.

        A direct child is ``other`` followed by one separator and exactly one
        more path component.
        """
        # The separator has to be part of the prefix: matching on ``other``
        # alone would also return the children of any sibling whose name
        # merely starts with the same characters (e.g. "ab" when asked for "a").
        prefix = other + SEP_SENTINEL
        for key in self.data.keys():
            if key.startswith(prefix) and SEP_SENTINEL not in key[len(prefix):]:
                yield key

    def get_many(self, keys):
        """Yield ``(key, value)`` for each key in ``keys``."""
        for key in keys:
            yield key, self.get(key)

    def get(self, key):
        """Return the value stored under ``key``; raises ``KeyError`` if absent."""
        return self.data[key]

    def set(self, key, value):
        """Store ``value`` under ``key``."""
        self.data[key] = value

    def delete(self, key):
        """Remove ``key``; raises ``KeyError`` if absent."""
        del self.data[key]

    def load(self):
        """Hook for loading persisted data; a no-op in the base provider."""
        pass

    def save(self):
        """Hook for persisting data; a no-op in the base provider."""
        pass

    def root(self):
        """Return a ``StorageDict`` for the root level of this provider."""
        return StorageDict(self)
class StorageDict(UserDict):
    """
    Dict-like view over one level of a provider's flattened key space.

    An instance is either the root (constructed from a provider) or a child
    of another ``StorageDict``. Nested dicts are not stored as values: the
    parent entry holds ``OBJ_SENTINEL`` and the children are written under
    their own joined keys.
    """

    def __init__(self, parent_or_provider, key=None):
        if isinstance(parent_or_provider, BaseProvider):
            self.provider = parent_or_provider
            self.parent = None
        else:
            self.parent = parent_or_provider
            self.provider = self.parent.provider
        self._key = key or ROOT_SENTINEL

    def _wrap(self, key, value):
        """Turn a stored ``OBJ_SENTINEL`` into the nested view it stands for."""
        if value == OBJ_SENTINEL:
            return self.__class__(self, key)
        return value

    def keys(self):
        """Return the (unflattened) keys directly below this level."""
        return map(true_key, self.provider.keys(self.key))

    def values(self):
        """Yield the values directly below this level."""
        for _, value in self.items():
            yield value

    def items(self):
        """
        Yield ``(key, value)`` pairs directly below this level.

        Values are looked up with the full provider key; the stripped key
        returned by ``keys()`` is not a valid provider key. Nested objects
        are returned as ``StorageDict`` views, matching ``__getitem__``.
        """
        for raw, value in self.provider.get_many(self.provider.keys(self.key)):
            key = true_key(raw)
            yield (key, self._wrap(key, value))

    def ensure(self, key, typ=dict):
        """Return ``self[key]``, first creating it as ``typ()`` if missing."""
        if key not in self:
            self[key] = typ()
        return self[key]

    def update(self, obj):
        """Assign every item of the mapping ``obj`` into this level."""
        for k, v in six.iteritems(obj):
            self[k] = v

    @property
    def data(self):
        """Plain-dict snapshot of this level, with nested levels expanded."""
        obj = {}
        for raw, value in self.provider.get_many(self.provider.keys(self.key)):
            key = true_key(raw)
            if value == OBJ_SENTINEL:
                value = self.__class__(self, key=key).data
            obj[key] = value
        return obj

    @property
    def key(self):
        """Full provider key of this level."""
        if self.parent is not None:
            return join_key(self.parent.key, self._key)
        return self._key

    def __setitem__(self, key, value):
        if isinstance(value, dict):
            # Children are written under their own keys; this entry only
            # records that a nested object exists here.
            obj = self.__class__(self, key)
            obj.update(value)
            value = OBJ_SENTINEL
        self.provider.set(join_key(self.key, key), value)

    def __getitem__(self, key):
        return self._wrap(key, self.provider.get(join_key(self.key, key)))

    def __delitem__(self, key):
        # NOTE(review): only this entry is removed; keys of nested children
        # are left in the provider.
        return self.provider.delete(join_key(self.key, key))

    def __contains__(self, key):
        return self.provider.exists(join_key(self.key, key))

54
disco/bot/providers/disk.py

@ -1,54 +0,0 @@
import os
import gevent
from disco.util.serializer import Serializer
from .base import BaseProvider
class DiskProvider(BaseProvider):
    """
    Provider that keeps data in memory and serializes it to a single file.

    Config keys read: ``format`` (default ``'pickle'``), ``path`` (default
    ``'storage'``; the format is appended as the extension), ``fsync`` and
    ``fsync_changes`` (save after that many changes when ``fsync`` is on),
    ``autosave`` and ``autosave_interval`` (periodic save in a greenlet).
    """

    def __init__(self, config):
        super(DiskProvider, self).__init__(config)
        self.format = config.get('format', 'pickle')
        self.path = config.get('path', 'storage') + '.' + self.format
        self.fsync = config.get('fsync', False)
        self.fsync_changes = config.get('fsync_changes', 1)

        self.autosave_task = None
        self.change_count = 0

    def autosave_loop(self, interval):
        """Save the data every ``interval`` seconds, forever."""
        while True:
            gevent.sleep(interval)
            self.save()

    def _on_change(self):
        """Count a mutation and save once ``fsync_changes`` have accumulated."""
        if not self.fsync:
            return

        self.change_count += 1
        if self.change_count >= self.fsync_changes:
            self.save()
            self.change_count = 0

    def load(self):
        """Start autosaving (if enabled) and read the file if it exists."""
        # Autosave must start even when there is no file yet, otherwise a
        # fresh store would never be written. The None check keeps repeated
        # load() calls from spawning duplicate loops.
        if self.autosave_task is None and self.config.get('autosave', True):
            self.autosave_task = gevent.spawn(
                self.autosave_loop,
                self.config.get('autosave_interval', 120))

        if not os.path.exists(self.path):
            return

        with open(self.path, 'r') as f:
            self.data = Serializer.loads(self.format, f.read())

    def save(self):
        """Serialize all data to ``self.path``, replacing its contents."""
        with open(self.path, 'w') as f:
            f.write(Serializer.dumps(self.format, self.data))

    def set(self, key, value):
        super(DiskProvider, self).set(key, value)
        self._on_change()

    def delete(self, key):
        super(DiskProvider, self).delete(key)
        self._on_change()

5
disco/bot/providers/memory.py

@ -1,5 +0,0 @@
from .base import BaseProvider
class MemoryProvider(BaseProvider):
    """Provider that adds nothing to ``BaseProvider``; every method is inherited."""

48
disco/bot/providers/redis.py

@ -1,48 +0,0 @@
from __future__ import absolute_import
import redis
from itertools import izip
from disco.util.serializer import Serializer
from .base import BaseProvider, SEP_SENTINEL
class RedisProvider(BaseProvider):
    """
    Provider storing each key as a Redis string holding the serialized value.

    Config keys read: ``format`` (default ``'pickle'``), ``host``, ``port``
    and ``db`` for the connection.
    """

    def __init__(self, config):
        super(RedisProvider, self).__init__(config)
        self.format = config.get('format', 'pickle')
        self.conn = None

    def load(self):
        """Create the Redis client from the configured host, port and db."""
        self.conn = redis.Redis(
            host=self.config.get('host', 'localhost'),
            port=self.config.get('port', 6379),
            db=self.config.get('db', 0))

    def exists(self, key):
        """Return whether ``key`` is present."""
        return bool(self.conn.exists(key))

    def keys(self, other):
        """Yield the keys that are direct children of the key ``other``."""
        # The separator is part of the prefix so that siblings whose name
        # merely starts with the same characters are not matched.
        # NOTE(review): ``other`` is interpolated into a glob pattern without
        # escaping; keys containing glob metacharacters would match too widely.
        prefix = other + SEP_SENTINEL
        for key in self.conn.scan_iter(u'{}*'.format(prefix)):
            key = key.decode('utf-8')
            if key.startswith(prefix) and SEP_SENTINEL not in key[len(prefix):]:
                yield key

    def get_many(self, keys):
        """Yield ``(key, value)`` for each key, fetched with a single MGET."""
        keys = list(keys)
        if not keys:
            # A plain return ends the generator; raising StopIteration inside
            # a generator is turned into RuntimeError by PEP 479.
            return

        # Builtin zip works on both Python 2 and 3 (izip is Python 2 only).
        for key, value in zip(keys, self.conn.mget(keys)):
            yield (key, Serializer.loads(self.format, value))

    def get(self, key):
        """Return the deserialized value stored under ``key``."""
        return Serializer.loads(self.format, self.conn.get(key))

    def set(self, key, value):
        """Serialize ``value`` and store it under ``key``."""
        self.conn.set(key, Serializer.dumps(self.format, value))

    def delete(self, key):
        """Remove ``key``."""
        self.conn.delete(key)

52
disco/bot/providers/rocksdb.py

@ -1,52 +0,0 @@
from __future__ import absolute_import
import six
import rocksdb
from itertools import izip
from six.moves import map
from disco.util.serializer import Serializer
from .base import BaseProvider, SEP_SENTINEL
class RocksDBProvider(BaseProvider):
    """
    Provider storing each key in a RocksDB database.

    Config keys read: ``format`` (default ``'pickle'``) and ``path`` (default
    ``'storage.db'``). Keys are stored UTF-8 encoded.
    """

    def __init__(self, config):
        super(RocksDBProvider, self).__init__(config)
        self.format = config.get('format', 'pickle')
        self.path = config.get('path', 'storage.db')
        self.db = None

    @staticmethod
    def k(k):
        """Return ``k`` as the byte string used for the database key."""
        if isinstance(k, bytes):
            return k
        # bytes(text) without an encoding raises TypeError on Python 3.
        return k.encode('utf-8')

    def load(self):
        """Open the database, creating it if it does not exist."""
        self.db = rocksdb.DB(self.path, rocksdb.Options(create_if_missing=True))

    def exists(self, key):
        """Return whether ``key`` is present."""
        return self.db.get(self.k(key)) is not None

    # TODO prefix extractor
    def keys(self, other):
        """Yield the keys that are direct children of the key ``other``."""
        # The separator is part of the prefix so that siblings whose name
        # merely starts with the same characters are not matched.
        prefix = other + SEP_SENTINEL
        it = self.db.iterkeys()
        it.seek_to_first()

        for key in it:
            key = key.decode('utf-8')
            if key.startswith(prefix) and SEP_SENTINEL not in key[len(prefix):]:
                yield key

    def get_many(self, keys):
        """Yield ``(key, value)`` for each key, fetched with one ``multi_get``."""
        # ``keys`` may be a generator: materialize it once so it is not
        # exhausted by the encoding pass before the pairs are produced.
        keys = list(keys)
        if not keys:
            return

        encoded = [self.k(key) for key in keys]
        # NOTE(review): relies on multi_get returning a mapping from the
        # requested byte keys to their values -- confirm against the
        # installed python-rocksdb version.
        found = self.db.multi_get(encoded)
        for key, raw in zip(keys, encoded):
            yield (key, Serializer.loads(self.format, found[raw].decode('utf-8')))

    def get(self, key):
        """Return the value stored under ``key``; raises ``KeyError`` if absent."""
        raw = self.db.get(self.k(key))
        if raw is None:
            # Same exception as BaseProvider.get for a missing key.
            raise KeyError(key)
        return Serializer.loads(self.format, raw.decode('utf-8'))

    def set(self, key, value):
        """Serialize ``value`` and store it under ``key``."""
        data = Serializer.dumps(self.format, value)
        if not isinstance(data, bytes):
            # Mirror the utf-8 decode done when reading values back.
            data = data.encode('utf-8')
        self.db.put(self.k(key), data)

    def delete(self, key):
        """Remove ``key``."""
        self.db.delete(self.k(key))

97
disco/bot/storage.py

@ -1,26 +1,87 @@
from .providers import load_provider
import os
from six.moves import UserDict
class Storage(object):
def __init__(self, ctx, config):
from disco.util.hashmap import HashMap
from disco.util.serializer import Serializer
class StorageHashMap(HashMap):
    """``HashMap`` whose backing ``data`` object is supplied by the caller."""

    def __init__(self, data):
        # The given object is stored as-is (no copy), so operations on this
        # map act directly on it.
        self.data = data
class ContextAwareProxy(UserDict):
def __init__(self, ctx):
self.ctx = ctx
self.config = config
self.provider = load_provider(config.provider)(config.config)
self.provider.load()
self.root = self.provider.root()
@property
def plugin(self):
return self.root.ensure('plugins').ensure(self.ctx['plugin'].name)
def data(self):
return self.ctx()
@property
def guild(self):
return self.plugin.ensure('guilds').ensure(self.ctx['guild'].id)
@property
def channel(self):
return self.plugin.ensure('channels').ensure(self.ctx['channel'].id)
class StorageDict(UserDict):
def __init__(self, parent, data):
self._parent = parent
self.data = data
@property
def user(self):
return self.plugin.ensure('users').ensure(self.ctx['user'].id)
def update(self, other):
self.data.update(other)
self._parent._update()
def __setitem__(self, key, value):
self.data[key] = value
self._parent._update()
def __delitem__(self, key):
del self.data[key]
self._parent._update()
class Storage(object):
    """
    Single-file key/value storage shared by plugins.

    All data lives in one dict that is read from ``config.path`` on
    construction (when that file exists) and written back by ``save()``
    using the serializer named by ``config.serializer``. When
    ``config.fsync`` is true, every change reported through ``_update()``
    triggers a save.
    """

    def __init__(self, ctx, config):
        self._ctx = ctx
        self._path = config.path
        self._serializer = config.serializer
        self._fsync = config.fsync
        self._data = {}

        # save() treats a falsy path as "persistence disabled"; apply the
        # same rule here so that e.g. path=None does not reach
        # os.path.exists(), which raises TypeError for None on Python 3.
        if self._path and os.path.exists(self._path):
            with open(self._path, 'r') as f:
                self._data = Serializer.loads(self._serializer, f.read())

    def __getitem__(self, key):
        """Return a map over the bucket stored under ``key``, creating it if needed."""
        bucket = self._data.setdefault(key, {})
        return StorageHashMap(StorageDict(self, bucket))

    def _update(self):
        """Called by child dicts after a mutation; saves when fsync is enabled."""
        if self._fsync:
            self.save()

    def save(self):
        """Write all data to the storage file; a no-op when no path is set."""
        if not self._path:
            return

        with open(self._path, 'w') as f:
            f.write(Serializer.dumps(self._serializer, self._data))

    # The accessors below hand the proxy a callable rather than a bucket, so
    # the context lookup happens when the callable is invoked, not here.

    def guild(self, key):
        """Return a proxy for bucket ``key`` scoped to the context's guild id."""
        return ContextAwareProxy(
            lambda: self['_g{}:{}'.format(self._ctx['guild'].id, key)]
        )

    def channel(self, key):
        """Return a proxy for bucket ``key`` scoped to the context's channel id."""
        return ContextAwareProxy(
            lambda: self['_c{}:{}'.format(self._ctx['channel'].id, key)]
        )

    def plugin(self, key):
        """Return a proxy for bucket ``key`` scoped to the context's plugin name."""
        return ContextAwareProxy(
            lambda: self['_p{}:{}'.format(self._ctx['plugin'].name, key)]
        )

    def user(self, key):
        """Return a proxy for bucket ``key`` scoped to the context's user id."""
        return ContextAwareProxy(
            lambda: self['_u{}:{}'.format(self._ctx['user'].id, key)]
        )

4
disco/types/channel.py

@ -244,7 +244,7 @@ class Channel(SlottedModel, Permissible):
def create_webhook(self, name=None, avatar=None):
return self.client.api.channels_webhooks_create(self.id, name, avatar)
def send_message(self, content=None, nonce=None, tts=False, attachment=None, embed=None):
def send_message(self, *args, **kwargs):
"""
Send a message in this channel.
@ -262,7 +262,7 @@ class Channel(SlottedModel, Permissible):
:class:`disco.types.message.Message`
The created message.
"""
return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed)
return self.client.api.channels_messages_create(self.id, *args, **kwargs)
def connect(self, *args, **kwargs):
"""

31
disco/util/sanitize.py

@ -0,0 +1,31 @@
import re
# Zero width (non-rendering) space that can be used to escape mentions
ZERO_WIDTH_SPACE = u'\u200B'

# A grave-looking character that can be used to escape codeblocks
MODIFIER_GRAVE_ACCENT = u'\u02CB'

# Regex which matches all possible mention combinations, this may be over-zealous
# but it's better safe than sorry. Inside a character class `|` is a literal
# pipe rather than alternation, so the classes list only the real prefixes.
MENTION_RE = re.compile(u'<[@#][!&]?([0-9]+)>|@everyone|@here')


def _re_sub_mention(mention):
    """
    Return the mention with a zero width space inserted after its `#` or `@`.

    The symbol itself is kept so the text still reads the same, it just no
    longer forms a mention. Accepts either a regex match object (which is
    what ``re.sub`` passes to a replacement callback) or the matched text.
    """
    if hasattr(mention, 'group'):
        mention = mention.group(0)

    for symbol in (u'#', u'@'):
        head, found, tail = mention.partition(symbol)
        if found:
            return head + symbol + ZERO_WIDTH_SPACE + tail
    return mention


def S(text, escape_mentions=True, escape_codeblocks=False):
    """
    Sanitize text before it is sent as message content.

    :param text: the text to sanitize.
    :param escape_mentions: break up anything matched by ``MENTION_RE``.
    :param escape_codeblocks: replace every backtick with a look-alike
        character.
    :returns: the sanitized text.
    """
    if escape_mentions:
        text = MENTION_RE.sub(_re_sub_mention, text)

    if escape_codeblocks:
        text = text.replace('`', MODIFIER_GRAVE_ACCENT)

    return text

33
examples/storage.py

@ -0,0 +1,33 @@
from disco.bot import Plugin
class BasicPlugin(Plugin):
    """Example plugin that keeps named tags in guild-scoped storage."""

    def load(self, ctx):
        super(BasicPlugin, self).load(ctx)
        # Mapping of tag name -> tag text, obtained from the guild storage
        self.tags = self.storage.guild('tags')

    @Plugin.command('add', '<name:str> <value:str...>', group='tags')
    def on_tags_add(self, event, name, value):
        """Create a tag, refusing to overwrite an existing one."""
        if name not in self.tags:
            self.tags[name] = value
            confirmation = u':ok_hand: created the tag {}'.format(name)
            return event.msg.reply(confirmation, sanitize=True)

        return event.msg.reply('That tag already exists!')

    @Plugin.command('get', '<name:str>', group='tags')
    def on_tags_get(self, event, name):
        """Reply with the text stored for a tag."""
        if name in self.tags:
            return event.msg.reply(self.tags[name], sanitize=True)

        return event.msg.reply('That tag does not exist!')

    @Plugin.command('delete', '<name:str>', group='tags', aliases=['del', 'rmv', 'remove'])
    def on_tags_delete(self, event, name):
        """Delete a tag if it exists."""
        if name in self.tags:
            del self.tags[name]
            confirmation = u':ok_hand: I deleted the {} tag for you'.format(name)
            return event.msg.reply(confirmation, sanitize=True)

        return event.msg.reply('That tag does not exist!')
Loading…
Cancel
Save