From 7d5370234da8f711fdab7635c646f21b8322e6b4 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sun, 9 Oct 2016 00:33:50 -0500 Subject: [PATCH] todo: make small commits - Add the concept of storage backends, not fully fleshed out at this point, but a good starting point - Add a generic serializer - Move mention_nick to the GuildMember object (I'm not sure this was a good idea, but we'll see) - Add a default config loader to the bot - Fix some Python 2.x/3.x unicode stuff - Start tracking greenlets on the Plugin level, this will help with reloading when its fully completed - Fix manhole locals being basically empty (sans the bot if relevant) - Add Channel.delete_messages_bulk - Add GuildMember.owner to check if the member owns the server --- disco/bot/__init__.py | 3 +- disco/bot/backends/__init__.py | 8 ++++ disco/bot/backends/base.py | 20 ++++++++++ disco/bot/backends/disk.py | 35 ++++++++++++++++++ disco/bot/backends/memory.py | 18 +++++++++ disco/bot/bot.py | 67 +++++++++++++++++++++++++++++----- disco/bot/parser.py | 5 ++- disco/bot/plugin.py | 64 ++++++++++++++++++-------------- disco/bot/storage.py | 21 +++++++++++ disco/client.py | 8 +++- disco/types/base.py | 10 ++++- disco/types/channel.py | 5 +++ disco/types/guild.py | 10 +++++ disco/types/user.py | 4 -- disco/util/config.py | 42 +++++++++++++++++++++ disco/util/serializer.py | 32 ++++++++++++++++ examples/basic_plugin.py | 12 +++--- 17 files changed, 310 insertions(+), 54 deletions(-) create mode 100644 disco/bot/backends/__init__.py create mode 100644 disco/bot/backends/base.py create mode 100644 disco/bot/backends/disk.py create mode 100644 disco/bot/backends/memory.py create mode 100644 disco/bot/storage.py create mode 100644 disco/util/config.py create mode 100644 disco/util/serializer.py diff --git a/disco/bot/__init__.py b/disco/bot/__init__.py index 1ea73fa..135ac2b 100644 --- a/disco/bot/__init__.py +++ b/disco/bot/__init__.py @@ -1,4 +1,5 @@ from disco.bot.bot import Bot, BotConfig from disco.bot.plugin import Plugin +from disco.util.config import Config -__all__ = ['Bot', 'BotConfig', 'Plugin'] +__all__ = ['Bot', 'BotConfig', 'Plugin', 'Config'] diff --git a/disco/bot/backends/__init__.py b/disco/bot/backends/__init__.py new file mode 100644 index 0000000..b8df75b --- /dev/null +++ b/disco/bot/backends/__init__.py @@ -0,0 +1,8 @@ +from .memory import MemoryBackend +from .disk import DiskBackend + + +BACKENDS = { + 'memory': MemoryBackend, + 'disk': DiskBackend, +} diff --git a/disco/bot/backends/base.py b/disco/bot/backends/base.py new file mode 100644 index 0000000..fe005f8 --- /dev/null +++ b/disco/bot/backends/base.py @@ -0,0 +1,20 @@ + +class BaseStorageBackend(object): + def base(self): + return self.storage + + def __getitem__(self, key): + return self.storage[key] + + def __setitem__(self, key, value): + self.storage[key] = value + + def __delitem__(self, key): + del self.storage[key] + + +class StorageDict(dict): + def ensure(self, name): + if not dict.__contains__(self, name): + dict.__setitem__(self, name, StorageDict()) + return dict.__getitem__(self, name) diff --git a/disco/bot/backends/disk.py b/disco/bot/backends/disk.py new file mode 100644 index 0000000..7de40b7 --- /dev/null +++ b/disco/bot/backends/disk.py @@ -0,0 +1,35 @@ +import os + +from .base import BaseStorageBackend, StorageDict + + +class DiskBackend(BaseStorageBackend): + def __init__(self, config): + self.format = config.get('format', 'json') + self.path = config.get('path', 'storage') + '.' + self.format + self.storage = StorageDict() + + @staticmethod + def get_format_functions(fmt): + if fmt == 'json': + from json import loads, dumps + return (loads, dumps) + elif fmt == 'yaml': + from pyyaml import load, dump + return (load, dump) + raise Exception('Unsupported format type {}'.format(fmt)) + + def load(self): + if not os.path.exists(self.path): + return + + decode, _ = self.get_format_functions(self.format) + + with open(self.path, 'r') as f: + self.storage = decode(f.read()) + + def dump(self): + _, encode = self.get_format_functions(self.format) + + with open(self.path, 'w') as f: + f.write(encode(self.storage)) diff --git a/disco/bot/backends/memory.py b/disco/bot/backends/memory.py new file mode 100644 index 0000000..5db93b9 --- /dev/null +++ b/disco/bot/backends/memory.py @@ -0,0 +1,18 @@ +from .base import BaseStorageBackend, StorageDict + + +class MemoryBackend(BaseStorageBackend): + def __init__(self): + self.storage = StorageDict() + + def base(self): + return self.storage + + def __getitem__(self, key): + return self.storage[key] + + def __setitem__(self, key, value): + self.storage[key] = value + + def __delitem__(self, key): + del self.storage[key] diff --git a/disco/bot/bot.py b/disco/bot/bot.py index c9b3c1a..ef49fb5 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -1,4 +1,5 @@ import re +import os import importlib import inspect @@ -7,10 +8,12 @@ from holster.threadlocal import ThreadLocal from disco.bot.plugin import Plugin from disco.bot.command import CommandEvent -# from disco.bot.storage import Storage +from disco.bot.storage import Storage +from disco.util.config import Config +from disco.util.serializer import Serializer -class BotConfig(object): +class BotConfig(Config): """ An object which is used to configure and define the runtime configuration for a bot. @@ -40,9 +43,14 @@ class BotConfig(object): message in a channel, and did not previously trigger a command. This is helpful for allowing edits to typod commands. plugin_config_provider : Optional[function] - If set, this function will be called before loading a plugin, with the - plugins class. Its expected to return a type of configuration object the - plugin understands. + If set, this function will replace the default configuration loading + function, which normally attempts to load a file located at config/plugin_name.fmt + where fmt is the plugin_config_format. The function here should return + a valid configuration object which the plugin understands. + plugin_config_format : str + The serilization format plugin configuration files are in. + plugin_config_dir : str + The directory plugin configuration is located within. """ token = None @@ -58,6 +66,13 @@ class BotConfig(object): commands_allow_edit = True plugin_config_provider = None + plugin_config_format = 'yaml' + plugin_config_dir = 'config' + + storage_enabled = False + storage_backend = 'memory' + storage_autosave = True + storage_autosave_interval = 120 class Bot(object): @@ -90,7 +105,9 @@ class Bot(object): self.ctx = ThreadLocal() # The storage object acts as a dynamic contextual aware store - # self.storage = Storage(self.ctx) + self.storage = None + if self.config.storage_enabled: + self.storage = Storage(self.ctx, self.config.from_prefix('storage')) if self.client.config.manhole_enable: self.client.manhole_locals['bot'] = self @@ -181,8 +198,12 @@ class Bot(object): raise StopIteration if mention_direct: - content = content.replace(self.client.state.me.mention, '', 1) - content = content.replace(self.client.state.me.mention_nick, '', 1) + if msg.guild: + member = msg.guild.get_member(self.client.state.me) + if member: + content = content.replace(member.mention, '', 1) + else: + content = content.replace(self.client.state.me.mention, '', 1) elif mention_everyone: content = content.replace('@everyone', '', 1) else: @@ -265,8 +286,11 @@ class Bot(object): if cls.__name__ in self.plugins: raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) - if not config and callable(self.config.plugin_config_provider): - config = self.config.plugin_config_provider(cls) + if not config: + if callable(self.config.plugin_config_provider): + config = self.config.plugin_config_provider(cls) + else: + config = self.load_plugin_config(cls) self.plugins[cls.__name__] = cls(self, config) self.plugins[cls.__name__].load() @@ -317,3 +341,26 @@ class Bot(object): break else: raise Exception('Could not find any plugins to load within module {}'.format(path)) + + def load_plugin_config(self, cls): + name = cls.__name__.lower() + if name.startswith('plugin'): + name = name[6:] + + path = os.path.join( + self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format + + if not os.path.exists(path): + if hasattr(cls, 'config_cls'): + return cls.config_cls() + return + + with open(path, 'r') as f: + data = Serializer.loads(self.config.plugin_config_format, f.read()) + + if hasattr(cls, 'config_cls'): + inst = cls.config_cls() + inst.update(data) + return inst + + return data diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 0ae2bb2..8f3483e 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -1,4 +1,5 @@ import re +import six import copy @@ -7,7 +8,7 @@ PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])') # Mapping of types TYPE_MAP = { - 'str': lambda ctx, data: str(data), + 'str': lambda ctx, data: str(data) if six.PY3 else unicode(data), 'int': lambda ctx, data: int(data), 'float': lambda ctx, data: int(data), 'snowflake': lambda ctx, data: int(data), @@ -160,7 +161,7 @@ class ArgumentSet(object): try: raw[idx] = self.convert(ctx, arg.types, r) except: - raise ArgumentError('cannot convert `{}` to `{}`'.format( + raise ArgumentError(u'cannot convert `{}` to `{}`'.format( r, ', '.join(arg.types) )) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 2362bbb..87f0e0b 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -1,7 +1,7 @@ import inspect import functools import gevent -import os +import weakref from holster.emitter import Priority @@ -27,6 +27,16 @@ class PluginDeco(object): return f return deco + @classmethod + def with_config(cls, config_cls): + """ + Sets the plugins config class to the specified config class. + """ + def deco(plugin_cls): + plugin_cls.config_cls = config_cls + return plugin_cls + return deco + @classmethod def listen(cls, event_name, priority=None): """ @@ -86,13 +96,14 @@ class PluginDeco(object): }) @classmethod - def schedule(cls, interval=60): + def schedule(cls, *args, **kwargs): """ Runs a function repeatedly, waiting for a specified interval """ return cls.add_meta_deco({ 'type': 'schedule', - 'interval': interval, + 'args': args, + 'kwargs': kwargs, }) @@ -131,10 +142,15 @@ class Plugin(LoggingClass, PluginDeco): self.listeners = [] self.commands = {} self.schedules = {} + self.greenlets = weakref.WeakSet() self._pre = {'command': [], 'listener': []} self._post = {'command': [], 'listener': []} + # TODO: when handling events/commands we need to track the greenlet in + # the greenlets set so we can termiante long running commands/listeners + # on reload. + for name, member in inspect.getmembers(self, predicate=inspect.ismethod): if hasattr(member, 'meta'): for meta in member.meta: @@ -143,11 +159,16 @@ class Plugin(LoggingClass, PluginDeco): elif meta['type'] == 'command': self.register_command(member, *meta['args'], **meta['kwargs']) elif meta['type'] == 'schedule': - self.register_schedule(member, meta['interval']) + self.register_schedule(member, *meta['args'], **meta['kwargs']) elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): when, typ = meta['type'].split('_', 1) self.register_trigger(typ, when, member) + def spawn(self, method, *args, **kwargs): + obj = gevent.spawn(method, *args, **kwargs) + self.greenlets.add(obj) + return obj + def execute(self, event): """ Executes a CommandEvent this plugin owns @@ -217,7 +238,7 @@ class Plugin(LoggingClass, PluginDeco): wrapped = functools.partial(self._dispatch, 'command', func) self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs) - def register_schedule(self, func, interval): + def register_schedule(self, func, interval, repeat=True, init=True): """ Registers a function to be called repeatedly, waiting for an interval duration. @@ -230,11 +251,16 @@ class Plugin(LoggingClass, PluginDeco): Interval (in seconds) to repeat the function on. """ def repeat(): - while True: + if init: func() + + while True: gevent.sleep(interval) + func() + if not repeat: + break - self.schedules[func.__name__] = gevent.spawn(repeat) + self.schedules[func.__name__] = self.spawn(repeat) def load(self): """ @@ -246,6 +272,9 @@ class Plugin(LoggingClass, PluginDeco): """ Called when the plugin is unloaded """ + for greenlet in self.greenlets: + greenlet.kill() + for listener in self.listeners: listener.remove() @@ -254,24 +283,3 @@ class Plugin(LoggingClass, PluginDeco): def reload(self): self.bot.reload_plugin(self.__class__) - - @staticmethod - def load_config_from_path(cls, path, format='json'): - inst = cls() - - if not os.path.exists(path): - return inst - - with open(path, 'r') as f: - data = f.read() - - if format == 'json': - import json - inst.__dict__.update(json.loads(data)) - elif format == 'yaml': - import yaml - inst.__dict__.update(yaml.load(data)) - else: - raise Exception('Unsupported config format {}'.format(format)) - - return inst diff --git a/disco/bot/storage.py b/disco/bot/storage.py new file mode 100644 index 0000000..45fa1f6 --- /dev/null +++ b/disco/bot/storage.py @@ -0,0 +1,21 @@ +from .backends import BACKENDS + + +class Storage(object): + def __init__(self, ctx, config): + self.ctx = ctx + self.backend = BACKENDS[config.backend] + # TODO: autosave + # config.autosave config.autosave_interval + + @property + def guild(self): + return self.backend.base().ensure('guilds').ensure(self.ctx['guild'].id) + + @property + def channel(self): + return self.backend.base().ensure('channels').ensure(self.ctx['channel'].id) + + @property + def user(self): + return self.backend.base().ensure('users').ensure(self.ctx['user'].id) diff --git a/disco/client.py b/disco/client.py index 24bc9a3..cfadc80 100644 --- a/disco/client.py +++ b/disco/client.py @@ -85,7 +85,13 @@ class Client(object): self.gw = GatewayClient(self, self.config.encoding_cls) if self.config.manhole_enable: - self.manhole_locals = {} + self.manhole_locals = { + 'client': self, + 'state': self.state, + 'api': self.api, + 'gw': self.gw + } + self.manhole = DiscoBackdoorServer(self.config.manhole_bind, banner='Disco Manhole', localf=lambda: self.manhole_locals) diff --git a/disco/types/base.py b/disco/types/base.py index 0f8899d..cb95f79 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -111,11 +111,17 @@ def datetime(data): def text(obj): - return six.text_type(obj) if obj else six.text_type() + if six.PY2: + return unicode(obj) + else: + return str(obj) def binary(obj): - return six.text_type(obj) if obj else six.text_type() + if six.PY2: + return unicode(obj) + else: + return bytes(obj) def field(typ, alias=None): diff --git a/disco/types/channel.py b/disco/types/channel.py index 6bc2330..92d3b6a 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -3,6 +3,7 @@ from holster.enum import Enum from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text from disco.types.permissions import PermissionValue +from disco.util import to_snowflake from disco.util.functional import cached_property from disco.types.user import User from disco.types.permissions import Permissions, Permissible @@ -241,6 +242,10 @@ class Channel(Model, Permissible): def delete_overwrite(self, ow): self.client.api.channels_permissions_delete(self.id, ow.id) + def delete_messages_bulk(self, messages): + messages = map(to_snowflake, messages) + self.client.api.channels_messages_delete_bulk(self.id, messages) + class MessageIterator(object): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index 5d8441b..9e38582 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -156,6 +156,16 @@ class GuildMember(Model): roles = self.roles + [role.id] self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles) + @property + def owner(self): + return self.guild.owner_id == self.id + + @property + def mention(self): + if self.nick: + return '<@!{}>'.format(self.id) + return self.user.mention + @cached_property def guild(self): return self.client.state.guilds.get(self.guild_id) diff --git a/disco/types/user.py b/disco/types/user.py index 7281796..ccd2e5b 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -13,10 +13,6 @@ class User(Model): def mention(self): return '<@{}>'.format(self.id) - @property - def mention_nick(self): - return '<@!{}>'.format(self.id) - def to_string(self): return '{}#{}'.format(self.username, self.discriminator) diff --git a/disco/util/config.py b/disco/util/config.py new file mode 100644 index 0000000..74927c8 --- /dev/null +++ b/disco/util/config.py @@ -0,0 +1,42 @@ +import os +import six + +from .serializer import Serializer + + +class Config(object): + def __init__(self, obj=None): + self.__dict__.update({ + k: getattr(self, k) for k in dir(self.__class__) + }) + + if obj: + self.__dict__.update(obj) + + @classmethod + def from_file(cls, path): + inst = cls() + + with open(path, 'r') as f: + data = f.read() + + _, ext = os.path.splitext(path) + Serializer.check_format(ext) + inst.__dict__.update(Serializer.load(ext, data)) + return inst + + def from_prefix(self, prefix): + prefix = prefix + '_' + obj = {} + + for k, v in six.iteritems(self.__dict__): + if k.startswith(prefix): + obj[k[len(prefix):]] = v + + return obj + + def update(self, other): + if isinstance(other, Config): + other = other.__dict__ + + self.__dict__.update(other) diff --git a/disco/util/serializer.py b/disco/util/serializer.py new file mode 100644 index 0000000..565d513 --- /dev/null +++ b/disco/util/serializer.py @@ -0,0 +1,32 @@ + + +class Serializer(object): + FORMATS = { + 'json', + 'yaml' + } + + @classmethod + def check_format(cls, fmt): + if fmt not in cls.FORMATS: + raise Exception('Unsupported serilization format: {}'.format(fmt)) + + @staticmethod + def json(): + from json import loads, dumps + return (loads, dumps) + + @staticmethod + def yaml(): + from yaml import load, dump + return (load, dump) + + @classmethod + def loads(cls, fmt, raw): + loads, _ = getattr(cls, fmt)() + return loads(raw) + + @classmethod + def dumps(cls, fmt, raw): + _, dumps = getattr(cls, fmt)() + return dumps(raw) diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 6ed449f..8001a69 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -95,15 +95,15 @@ class BasicPlugin(Plugin): json.dumps(perms.to_dict(), sort_keys=True, indent=2, separators=(',', ': ')) )) - """ @Plugin.command('tag', ' [value:str]') def on_tag(self, event, name, value=None): + tags = self.storage.guild.ensure('tags') + if value: - self.storage.guild['tags'][name] = value + tags[name] = value event.msg.reply(':ok_hand:') else: - if name in self.storage.guild['tags']: - return event.msg.reply(self.storage.guild['tags'][name]) + if name in tags: + return event.msg.reply(tags[name]) else: - event.msg.reply('Unknown tag `{}`'.format(name)) - """ + return event.msg.reply('Unknown tag: `{}`'.format(name))