diff --git a/.gitignore b/.gitignore index 87a6c02..1346a5f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ dist/ disco*.egg-info/ docs/_build storage.db +storage.json *.dca diff --git a/disco/api/client.py b/disco/api/client.py index e1adee7..448d2f7 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -3,6 +3,7 @@ import json from disco.api.http import Routes, HTTPClient from disco.util.logging import LoggingClass +from disco.util.sanitize import S from disco.types.user import User from disco.types.message import Message @@ -88,13 +89,15 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) return Message.create(self.client, r.json()) - def channels_messages_create(self, channel, content=None, nonce=None, tts=False, attachment=None, embed=None): + def channels_messages_create(self, channel, content=None, nonce=None, tts=False, attachment=None, embed=None, sanitize=False): payload = { 'nonce': nonce, 'tts': tts, } if content: + if sanitize: + content = S(content) payload['content'] = content if embed: @@ -109,10 +112,12 @@ class APIClient(LoggingClass): return Message.create(self.client, r.json()) - def channels_messages_modify(self, channel, message, content=None, embed=None): + def channels_messages_modify(self, channel, message, content=None, embed=None, sanitize=False): payload = {} if content: + if sanitize: + content = S(content) payload['content'] = content if embed: diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 2631cb0..e7ccad4 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -81,12 +81,13 @@ class BotConfig(Config): commands_group_abbrev = True plugin_config_provider = None - plugin_config_format = 'yaml' + plugin_config_format = 'json' plugin_config_dir = 'config' storage_enabled = True - storage_provider = 'memory' - storage_config = {} + storage_fsync = True + storage_serializer = 'json' + storage_path = 'storage.json' class Bot(LoggingClass): diff --git a/disco/bot/providers/__init__.py b/disco/bot/providers/__init__.py deleted file mode 100644 index 3ec985f..0000000 --- a/disco/bot/providers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import inspect -import importlib - -from .base import BaseProvider - - -def load_provider(name): - try: - mod = importlib.import_module('disco.bot.providers.' + name) - except ImportError: - mod = importlib.import_module(name) - - for entry in filter(inspect.isclass, map(lambda i: getattr(mod, i), dir(mod))): - if issubclass(entry, BaseProvider) and entry != BaseProvider: - return entry diff --git a/disco/bot/providers/base.py b/disco/bot/providers/base.py deleted file mode 100644 index 0f14f3d..0000000 --- a/disco/bot/providers/base.py +++ /dev/null @@ -1,134 +0,0 @@ -import six -import pickle - -from six.moves import map, UserDict - - -ROOT_SENTINEL = u'\u200B' -SEP_SENTINEL = u'\u200D' -OBJ_SENTINEL = u'\u200C' -CAST_SENTINEL = u'\u24EA' - - -def join_key(*args): - nargs = [] - for arg in args: - if not isinstance(arg, six.string_types): - arg = CAST_SENTINEL + pickle.dumps(arg) - nargs.append(arg) - return SEP_SENTINEL.join(nargs) - - -def true_key(key): - key = key.rsplit(SEP_SENTINEL, 1)[-1] - if key.startswith(CAST_SENTINEL): - return pickle.loads(key) - return key - - -class BaseProvider(object): - def __init__(self, config): - self.config = config - self.data = {} - - def exists(self, key): - return key in self.data - - def keys(self, other): - count = other.count(SEP_SENTINEL) + 1 - for key in self.data.keys(): - if key.startswith(other) and key.count(SEP_SENTINEL) == count: - yield key - - def get_many(self, keys): - for key in keys: - yield key, self.get(key) - - def get(self, key): - return self.data[key] - - def set(self, key, value): - self.data[key] = value - - def delete(self, key): - del self.data[key] - - def load(self): - pass - - def save(self): - pass - - def root(self): - return StorageDict(self) - - -class StorageDict(UserDict): - def __init__(self, parent_or_provider, key=None): - if isinstance(parent_or_provider, BaseProvider): - self.provider = parent_or_provider - self.parent = None - else: - self.parent = parent_or_provider - self.provider = self.parent.provider - self._key = key or ROOT_SENTINEL - - def keys(self): - return map(true_key, self.provider.keys(self.key)) - - def values(self): - for key in self.keys(): - yield self.provider.get(key) - - def items(self): - for key in self.keys(): - yield (true_key(key), self.provider.get(key)) - - def ensure(self, key, typ=dict): - if key not in self: - self[key] = typ() - return self[key] - - def update(self, obj): - for k, v in six.iteritems(obj): - self[k] = v - - @property - def data(self): - obj = {} - - for raw, value in self.provider.get_many(self.provider.keys(self.key)): - key = true_key(raw) - - if value == OBJ_SENTINEL: - value = self.__class__(self, key=key).data - obj[key] = value - return obj - - @property - def key(self): - if self.parent is not None: - return join_key(self.parent.key, self._key) - return self._key - - def __setitem__(self, key, value): - if isinstance(value, dict): - obj = self.__class__(self, key) - obj.update(value) - value = OBJ_SENTINEL - - self.provider.set(join_key(self.key, key), value) - - def __getitem__(self, key): - res = self.provider.get(join_key(self.key, key)) - - if res == OBJ_SENTINEL: - return self.__class__(self, key) - - return res - - def __delitem__(self, key): - return self.provider.delete(join_key(self.key, key)) - - def __contains__(self, key): - return self.provider.exists(join_key(self.key, key)) diff --git a/disco/bot/providers/disk.py b/disco/bot/providers/disk.py deleted file mode 100644 index af259e1..0000000 --- a/disco/bot/providers/disk.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -import gevent - -from disco.util.serializer import Serializer -from .base import BaseProvider - - -class DiskProvider(BaseProvider): - def __init__(self, config): - super(DiskProvider, self).__init__(config) - self.format = config.get('format', 'pickle') - self.path = config.get('path', 'storage') + '.' + self.format - self.fsync = config.get('fsync', False) - self.fsync_changes = config.get('fsync_changes', 1) - - self.autosave_task = None - self.change_count = 0 - - def autosave_loop(self, interval): - while True: - gevent.sleep(interval) - self.save() - - def _on_change(self): - if self.fsync: - self.change_count += 1 - - if self.change_count >= self.fsync_changes: - self.save() - self.change_count = 0 - - def load(self): - if not os.path.exists(self.path): - return - - if self.config.get('autosave', True): - self.autosave_task = gevent.spawn( - self.autosave_loop, - self.config.get('autosave_interval', 120)) - - with open(self.path, 'r') as f: - self.data = Serializer.loads(self.format, f.read()) - - def save(self): - with open(self.path, 'w') as f: - f.write(Serializer.dumps(self.format, self.data)) - - def set(self, key, value): - super(DiskProvider, self).set(key, value) - self._on_change() - - def delete(self, key): - super(DiskProvider, self).delete(key) - self._on_change() diff --git a/disco/bot/providers/memory.py b/disco/bot/providers/memory.py deleted file mode 100644 index 17ad47b..0000000 --- a/disco/bot/providers/memory.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base import BaseProvider - - -class MemoryProvider(BaseProvider): - pass diff --git a/disco/bot/providers/redis.py b/disco/bot/providers/redis.py deleted file mode 100644 index f5e1375..0000000 --- a/disco/bot/providers/redis.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import absolute_import - -import redis - -from itertools import izip - -from disco.util.serializer import Serializer -from .base import BaseProvider, SEP_SENTINEL - - -class RedisProvider(BaseProvider): - def __init__(self, config): - super(RedisProvider, self).__init__(config) - self.format = config.get('format', 'pickle') - self.conn = None - - def load(self): - self.conn = redis.Redis( - host=self.config.get('host', 'localhost'), - port=self.config.get('port', 6379), - db=self.config.get('db', 0)) - - def exists(self, key): - return self.conn.exists(key) - - def keys(self, other): - count = other.count(SEP_SENTINEL) + 1 - for key in self.conn.scan_iter(u'{}*'.format(other)): - key = key.decode('utf-8') - if key.count(SEP_SENTINEL) == count: - yield key - - def get_many(self, keys): - keys = list(keys) - if not len(keys): - raise StopIteration - - for key, value in izip(keys, self.conn.mget(keys)): - yield (key, Serializer.loads(self.format, value)) - - def get(self, key): - return Serializer.loads(self.format, self.conn.get(key)) - - def set(self, key, value): - self.conn.set(key, Serializer.dumps(self.format, value)) - - def delete(self, key): - self.conn.delete(key) diff --git a/disco/bot/providers/rocksdb.py b/disco/bot/providers/rocksdb.py deleted file mode 100644 index 986268d..0000000 --- a/disco/bot/providers/rocksdb.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import absolute_import - -import six -import rocksdb - -from itertools import izip -from six.moves import map - -from disco.util.serializer import Serializer -from .base import BaseProvider, SEP_SENTINEL - - -class RocksDBProvider(BaseProvider): - def __init__(self, config): - super(RocksDBProvider, self).__init__(config) - self.format = config.get('format', 'pickle') - self.path = config.get('path', 'storage.db') - self.db = None - - @staticmethod - def k(k): - return bytes(k) if six.PY3 else str(k.encode('utf-8')) - - def load(self): - self.db = rocksdb.DB(self.path, rocksdb.Options(create_if_missing=True)) - - def exists(self, key): - return self.db.get(self.k(key)) is not None - - # TODO prefix extractor - def keys(self, other): - count = other.count(SEP_SENTINEL) + 1 - it = self.db.iterkeys() - it.seek_to_first() - - for key in it: - key = key.decode('utf-8') - if key.startswith(other) and key.count(SEP_SENTINEL) == count: - yield key - - def get_many(self, keys): - for key, value in izip(keys, self.db.multi_get(list(map(self.k, keys)))): - yield (key, Serializer.loads(self.format, value.decode('utf-8'))) - - def get(self, key): - return Serializer.loads(self.format, self.db.get(self.k(key)).decode('utf-8')) - - def set(self, key, value): - self.db.put(self.k(key), Serializer.dumps(self.format, value)) - - def delete(self, key): - self.db.delete(self.k(key)) diff --git a/disco/bot/storage.py b/disco/bot/storage.py index 812d79c..5b03a07 100644 --- a/disco/bot/storage.py +++ b/disco/bot/storage.py @@ -1,26 +1,87 @@ -from .providers import load_provider +import os +from six.moves import UserDict -class Storage(object): - def __init__(self, ctx, config): +from disco.util.hashmap import HashMap +from disco.util.serializer import Serializer + + +class StorageHashMap(HashMap): + def __init__(self, data): + self.data = data + + +class ContextAwareProxy(UserDict): + def __init__(self, ctx): self.ctx = ctx - self.config = config - self.provider = load_provider(config.provider)(config.config) - self.provider.load() - self.root = self.provider.root() @property - def plugin(self): - return self.root.ensure('plugins').ensure(self.ctx['plugin'].name) + def data(self): + return self.ctx() - @property - def guild(self): - return self.plugin.ensure('guilds').ensure(self.ctx['guild'].id) - @property - def channel(self): - return self.plugin.ensure('channels').ensure(self.ctx['channel'].id) +class StorageDict(UserDict): + def __init__(self, parent, data): + self._parent = parent + self.data = data - @property - def user(self): - return self.plugin.ensure('users').ensure(self.ctx['user'].id) + def update(self, other): + self.data.update(other) + self._parent._update() + + def __setitem__(self, key, value): + self.data[key] = value + self._parent._update() + + def __delitem__(self, key): + del self.data[key] + self._parent._update() + + +class Storage(object): + def __init__(self, ctx, config): + self._ctx = ctx + self._path = config.path + self._serializer = config.serializer + self._fsync = config.fsync + self._data = {} + + if os.path.exists(self._path): + with open(self._path, 'r') as f: + self._data = Serializer.loads(self._serializer, f.read()) + + def __getitem__(self, key): + if key not in self._data: + self._data[key] = {} + return StorageHashMap(StorageDict(self, self._data[key])) + + def _update(self): + if self._fsync: + self.save() + + def save(self): + if not self._path: + return + + with open(self._path, 'w') as f: + f.write(Serializer.dumps(self._serializer, self._data)) + + def guild(self, key): + return ContextAwareProxy( + lambda: self['_g{}:{}'.format(self._ctx['guild'].id, key)] + ) + + def channel(self, key): + return ContextAwareProxy( + lambda: self['_c{}:{}'.format(self._ctx['channel'].id, key)] + ) + + def plugin(self, key): + return ContextAwareProxy( + lambda: self['_p{}:{}'.format(self._ctx['plugin'].name, key)] + ) + + def user(self, key): + return ContextAwareProxy( + lambda: self['_u{}:{}'.format(self._ctx['user'].id, key)] + ) diff --git a/disco/types/channel.py b/disco/types/channel.py index 1d373a3..48a10c1 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -244,7 +244,7 @@ class Channel(SlottedModel, Permissible): def create_webhook(self, name=None, avatar=None): return self.client.api.channels_webhooks_create(self.id, name, avatar) - def send_message(self, content=None, nonce=None, tts=False, attachment=None, embed=None): + def send_message(self, *args, **kwargs): """ Send a message in this channel. @@ -262,7 +262,7 @@ class Channel(SlottedModel, Permissible): :class:`disco.types.message.Message` The created message. """ - return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed) + return self.client.api.channels_messages_create(self.id, *args, **kwargs) def connect(self, *args, **kwargs): """ diff --git a/disco/util/sanitize.py b/disco/util/sanitize.py new file mode 100644 index 0000000..e12287e --- /dev/null +++ b/disco/util/sanitize.py @@ -0,0 +1,31 @@ +import re + + +# Zero width (non-rendering) space that can be used to escape mentions +ZERO_WIDTH_SPACE = u'\u200B' + +# A grave-looking character that can be used to escape codeblocks +MODIFIER_GRAVE_ACCENT = u'\u02CB' + +# Regex which matches all possible mention combinations, this may be over-zealous +# but its better safe than sorry. +MENTION_RE = re.compile('<[@|#][!|&]?([0-9]+)>|@everyone') + + +def _re_sub_mention(mention): + if '#' in mention: + return ZERO_WIDTH_SPACE.join(mention.split('#', 1)) + elif '@' in mention: + return ZERO_WIDTH_SPACE.join(mention.split('@', 1)) + else: + return mention + + +def S(text, escape_mentions=True, escape_codeblocks=False): + if escape_mentions: + text = MENTION_RE.sub(_re_sub_mention, text) + + if escape_codeblocks: + text = text.replace('`', MODIFIER_GRAVE_ACCENT) + + return text diff --git a/examples/storage.py b/examples/storage.py new file mode 100644 index 0000000..c8e5ce3 --- /dev/null +++ b/examples/storage.py @@ -0,0 +1,33 @@ +from disco.bot import Plugin + + +class BasicPlugin(Plugin): + def load(self, ctx): + super(BasicPlugin, self).load(ctx) + self.tags = self.storage.guild('tags') + + @Plugin.command('add', ' ', group='tags') + def on_tags_add(self, event, name, value): + if name in self.tags: + return event.msg.reply('That tag already exists!') + + self.tags[name] = value + return event.msg.reply(u':ok_hand: created the tag {}'.format(name), sanitize=True) + + @Plugin.command('get', '', group='tags') + def on_tags_get(self, event, name): + if name not in self.tags: + return event.msg.reply('That tag does not exist!') + + return event.msg.reply(self.tags[name], sanitize=True) + + @Plugin.command('delete', '', group='tags', aliases=['del', 'rmv', 'remove']) + def on_tags_delete(self, event, name): + if name not in self.tags: + return event.msg.reply('That tag does not exist!') + + del self.tags[name] + + return event.msg.reply(u':ok_hand: I deleted the {} tag for you'.format( + name + ), sanitize=True)