From c4d4b40107596c0c986103a2c5c182667bb4e43c Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 22 Sep 2016 04:18:57 -0500 Subject: [PATCH] plugins, commands, etc --- disco/bot/__init__.py | 1 + disco/bot/bot.py | 116 +++++++++++++++++++++++++++++++++++++++ disco/bot/command.py | 29 ++++++++++ disco/bot/plugin.py | 65 ++++++++++++++++++++++ disco/cli.py | 8 +-- disco/client.py | 7 +++ disco/gateway/client.py | 14 ++--- disco/gateway/events.py | 12 +++- disco/state.py | 52 ++++++++++++++++++ disco/types/base.py | 5 ++ disco/types/channel.py | 14 ++++- disco/types/guild.py | 22 ++++++-- disco/types/message.py | 54 +++++++++++++++++- disco/types/user.py | 4 +- disco/types/voice.py | 4 +- disco/util/__init__.py | 18 ++++++ disco/util/cache.py | 7 +++ examples/basic_plugin.py | 18 ++++++ 18 files changed, 423 insertions(+), 27 deletions(-) create mode 100644 disco/bot/bot.py create mode 100644 disco/bot/command.py create mode 100644 disco/bot/plugin.py create mode 100644 disco/types/base.py create mode 100644 disco/util/cache.py create mode 100644 examples/basic_plugin.py diff --git a/disco/bot/__init__.py b/disco/bot/__init__.py index e69de29..e8337f1 100644 --- a/disco/bot/__init__.py +++ b/disco/bot/__init__.py @@ -0,0 +1 @@ +from disco.bot.bot import Bot diff --git a/disco/bot/bot.py b/disco/bot/bot.py new file mode 100644 index 0000000..bc65b6d --- /dev/null +++ b/disco/bot/bot.py @@ -0,0 +1,116 @@ +import re + + +class BotConfig(object): + # Whether the bot must be mentioned to respond to a command + command_require_mention = True + + # Rules about what mentions trigger the bot + command_mention_rules = { + # 'here': False, + 'everyone': False, + 'role': True, + 'user': True, + } + + # The prefix required for EVERY command + command_prefix = '' + + # Whether an edited message can trigger a command + command_allow_edit = True + + +class Bot(object): + def __init__(self, client, config=None): + self.client = client + self.config = config or BotConfig() + + self.plugins = {} + + self.client.events.on('MessageCreate', self.on_message_create) + self.client.events.on('MessageUpdate', self.on_message_update) + + # Stores the last message for every single channel + self.last_message_cache = {} + + # Stores a giant regex matcher for all commands + self.command_matches_re = None + + @property + def commands(self): + for plugin in self.plugins.values(): + for command in plugin.commands: + yield command + + def compute_command_matches_re(self): + re_str = '|'.join(command.regex for command in self.commands) + print re_str + if re_str: + self.command_matches_re = re.compile(re_str) + else: + self.command_matches_re = None + + def handle_message(self, msg): + content = msg.content + + if self.config.command_require_mention: + match = any(( + self.config.command_mention_rules['user'] and msg.is_mentioned(self.client.state.me), + self.config.command_mention_rules['everyone'] and msg.mention_everyone, + self.config.command_mention_rules['role'] and any(map(msg.is_mentioned, + msg.guild.get_member(self.client.state.me).roles + )))) + + if not match: + return False + + content = msg.without_mentions.strip() + + if self.config.command_prefix and not content.startswith(self.config.command_prefix): + return False + + if not self.command_matches_re or not self.command_matches_re.match(content): + return False + + for command in self.commands: + if command.compiled_regex.match(content): + command.execute(msg) + + return False + + def on_message_create(self, event): + if self.config.command_allow_edit: + self.last_message_cache[event.message.channel_id] = (event.message, False) + + self.handle_message(event.message) + + def on_message_update(self, event): + if self.config.command_allow_edit: + msg = self.last_message_cache.get(event.message.channel_id) + if msg and event.message.id == msg[0].id: + triggered = msg[1] + + if not triggered: + triggered = self.handle_message(event.message) + + self.last_message_cache[event.message.channel_id] = (event.message, triggered) + + def add_plugin(self, cls): + if cls.__name__ in self.plugins: + raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) + + self.plugins[cls.__name__] = cls(self) + self.plugins[cls.__name__].load() + self.compute_command_matches_re() + + def rmv_plugin(self, cls): + if cls.__name__ not in self.plugins: + raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__)) + + self.plugins[cls.__name__].unload() + self.plugins[cls.__name__].destroy() + del self.plugins[cls.__name__] + self.compute_command_matches_re() + + def run_forever(self): + self.client.run_forever() diff --git a/disco/bot/command.py b/disco/bot/command.py new file mode 100644 index 0000000..85918b0 --- /dev/null +++ b/disco/bot/command.py @@ -0,0 +1,29 @@ +import re + +from disco.util.cache import cached_property + +ARGS_REGEX = '( (.*)$|$)' + + +class Command(object): + def __init__(self, func, trigger, aliases=None, group=None, is_regex=False): + self.func = func + self.triggers = [trigger] + (aliases or []) + + self.group = group + self.is_regex = is_regex + + def execute(self, msg): + self.func(msg) + + @cached_property + def compiled_regex(self): + return re.compile(self.regex) + + @property + def regex(self): + if self.is_regex: + return '|'.join(self.triggers) + else: + group = self.group + ' ' if self.group else '' + return '|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py new file mode 100644 index 0000000..7773b9a --- /dev/null +++ b/disco/bot/plugin.py @@ -0,0 +1,65 @@ +import inspect + +from disco.bot.command import Command + + +class PluginDeco(object): + @staticmethod + def listen(event_name): + def deco(f): + if not hasattr(f, 'meta'): + f.meta = [] + + f.meta.append({ + 'type': 'listener', + 'event_name': event_name, + }) + + return f + return deco + + @staticmethod + def command(*args, **kwargs): + def deco(f): + if not hasattr(f, 'meta'): + f.meta = [] + + f.meta.append({ + 'type': 'command', + 'args': args, + 'kwargs': kwargs, + }) + + return f + return deco + + +class Plugin(PluginDeco): + def __init__(self, bot): + self.bot = bot + + self.listeners = [] + self.commands = [] + + for name, member in inspect.getmembers(self, predicate=inspect.ismethod): + if hasattr(member, 'meta'): + for meta in member.meta: + if meta['type'] == 'listener': + self.register_listener(member, meta['event_name']) + elif meta['type'] == 'command': + self.register_command(member, *meta['args'], **meta['kwargs']) + + def register_listener(self, func, name): + self.listeners.append(self.bot.client.events.on(name, func)) + + def register_command(self, func, *args, **kwargs): + self.commands.append(Command(func, *args, **kwargs)) + + def destroy(self): + map(lambda k: k.remove(), self._events) + + def load(self): + pass + + def unload(self): + pass diff --git a/disco/cli.py b/disco/cli.py index 90cc701..5c05fb9 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -3,14 +3,14 @@ import argparse from gevent import monkey +monkey.patch_all() parser = argparse.ArgumentParser() parser.add_argument('--token', help='Bot Authentication Token', required=True) logging.basicConfig(level=logging.INFO) -def main(): - monkey.patch_all() +def disco_main(): args = parser.parse_args() from disco.util.token import is_valid_token @@ -20,7 +20,7 @@ def main(): return from disco.client import DiscoClient - DiscoClient(args.token).run_forever() + return DiscoClient(args.token) if __name__ == '__main__': - main() + disco_main().run_forever() diff --git a/disco/client.py b/disco/client.py index f2e3aff..d4df839 100644 --- a/disco/client.py +++ b/disco/client.py @@ -1,5 +1,9 @@ import logging +import gevent +from holster.emitter import Emitter + +from disco.state import State from disco.api.client import APIClient from disco.gateway.client import GatewayClient @@ -12,6 +16,9 @@ class DiscoClient(object): self.token = token self.sharding = sharding or {'number': 0, 'total': 1} + self.events = Emitter(gevent.spawn) + + self.state = State(self) self.api = APIClient(self) self.gw = GatewayClient(self) diff --git a/disco/gateway/client.py b/disco/gateway/client.py index dc65919..a677777 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -3,11 +3,8 @@ import gevent import json import zlib -from holster.emitter import Emitter -# from holster.util import SimpleObject - from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket -from disco.gateway.events import GatewayEvent, Ready +from disco.gateway.events import GatewayEvent from disco.util.logging import LoggingClass GATEWAY_VERSION = 6 @@ -28,9 +25,8 @@ class GatewayClient(LoggingClass): def __init__(self, client): super(GatewayClient, self).__init__() self.client = client - self.emitter = Emitter(gevent.spawn) - self.emitter.on(Ready, self.on_ready) + self.client.events.on('Ready', self.on_ready) # Websocket connection self.ws = None @@ -60,9 +56,9 @@ class GatewayClient(LoggingClass): gevent.sleep(interval / 1000) def handle_dispatch(self, packet): - cls, obj = GatewayEvent.from_dispatch(packet) - self.log.info('Dispatching %s', cls) - self.emitter.emit(cls, obj) + obj = GatewayEvent.from_dispatch(self.client, packet) + self.log.info('Dispatching %s', obj.__class__.__name__) + self.client.events.emit(obj.__class__.__name__, obj) def handle_heartbeat(self, packet): pass diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 8b6c079..eac3991 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -1,17 +1,25 @@ import inflection import skema +from disco.util import recursive_find_matching +from disco.types.base import BaseType from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState class GatewayEvent(skema.Model): @staticmethod - def from_dispatch(obj): + def from_dispatch(client, obj): cls = globals().get(inflection.camelize(obj['t'].lower())) if not cls: raise Exception('Could not find cls for {}'.format(obj['t'])) - return cls, cls.create(obj['d']) + obj = cls.create(obj['d']) + + # TODO: use skema info + for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)): + item.client = client + + return obj @classmethod def create(cls, obj): diff --git a/disco/state.py b/disco/state.py index e69de29..cc7289e 100644 --- a/disco/state.py +++ b/disco/state.py @@ -0,0 +1,52 @@ + + +class State(object): + def __init__(self, client): + self.client = client + + self.me = None + + self.channels = {} + self.guilds = {} + + self.client.events.on('Ready', self.on_ready) + + # Guilds + self.client.events.on('GuildCreate', self.on_guild_create) + self.client.events.on('GuildUpdate', self.on_guild_update) + self.client.events.on('GuildDelete', self.on_guild_delete) + + # Channels + self.client.events.on('ChannelCreate', self.on_channel_create) + self.client.events.on('ChannelUpdate', self.on_channel_update) + self.client.events.on('ChannelDelete', self.on_channel_delete) + + def on_ready(self, event): + self.me = event.user + + def on_guild_create(self, event): + self.guilds[event.guild.id] = event.guild + + for channel in event.guild.channels: + self.channels[channel.id] = channel + + def on_guild_update(self, event): + # TODO + pass + + def on_guild_delete(self, event): + if event.guild_id in self.guilds: + del self.guilds[event.guild_id] + + # CHANNELS? + + def on_channel_create(self, event): + self.channels[event.channel.id] = event.channel + + def on_channel_update(self, event): + # TODO + pass + + def on_channel_delete(self, event): + if event.channel.id in self.channels: + del self.channels[event.channel.id] diff --git a/disco/types/base.py b/disco/types/base.py new file mode 100644 index 0000000..fb590a9 --- /dev/null +++ b/disco/types/base.py @@ -0,0 +1,5 @@ +import skema + + +class BaseType(skema.Model): + pass diff --git a/disco/types/channel.py b/disco/types/channel.py index 4ae1465..abbf1f6 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -2,7 +2,8 @@ import skema from holster.enum import Enum -# from disco.types.guild import Guild +from disco.util.cache import cached_property +from disco.types.base import BaseType from disco.types.user import User @@ -19,7 +20,7 @@ PermissionOverwriteType = Enum( ) -class PermissionOverwrite(skema.Model): +class PermissionOverwrite(BaseType): id = skema.SnowflakeType() type = skema.StringType(choices=PermissionOverwriteType.ALL_VALUES) @@ -27,8 +28,9 @@ class PermissionOverwrite(skema.Model): deny = skema.IntType() -class Channel(skema.Model): +class Channel(BaseType): id = skema.SnowflakeType() + guild_id = skema.SnowflakeType(required=False) name = skema.StringType() topic = skema.StringType() @@ -40,3 +42,9 @@ class Channel(skema.Model): type = skema.IntType(choices=ChannelType.ALL_VALUES) permission_overwrites = skema.ListType(skema.ModelType(PermissionOverwrite)) + + @cached_property + def guild(self): + print self.guild_id + print self.client.state.guilds + return self.client.state.guilds.get(self.guild_id) diff --git a/disco/types/guild.py b/disco/types/guild.py index 8ff6835..fa984c4 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -1,12 +1,14 @@ import skema +from disco.util.cache import cached_property +from disco.types.base import BaseType from disco.util.types import PreHookType from disco.types.user import User from disco.types.voice import VoiceState from disco.types.channel import Channel -class Emoji(skema.Model): +class Emoji(BaseType): id = skema.SnowflakeType() name = skema.StringType() require_colons = skema.BooleanType() @@ -14,7 +16,7 @@ class Emoji(skema.Model): roles = skema.ListType(skema.SnowflakeType()) -class Role(skema.Model): +class Role(BaseType): id = skema.SnowflakeType() name = skema.StringType() hoist = skema.BooleanType() @@ -24,7 +26,7 @@ class Role(skema.Model): position = skema.IntType() -class GuildMember(skema.Model): +class GuildMember(BaseType): user = skema.ModelType(User) mute = skema.BooleanType() deaf = skema.BooleanType() @@ -32,7 +34,7 @@ class GuildMember(skema.Model): roles = skema.ListType(skema.SnowflakeType()) -class Guild(skema.Model): +class Guild(BaseType): id = skema.SnowflakeType() owner_id = skema.SnowflakeType() @@ -56,3 +58,15 @@ class Guild(skema.Model): channels = skema.ListType(skema.ModelType(Channel)) roles = skema.ListType(skema.ModelType(Role)) emojis = skema.ListType(skema.ModelType(Emoji)) + + @cached_property + def members_dict(self): + return {i.user.id: i for i in self.members} + + def get_member(self, user): + return self.members_dict.get(user.id) + + def validate_channels(self, ctx): + if self.channels: + for channel in self.channels: + channel.guild_id = self.id diff --git a/disco/types/message.py b/disco/types/message.py index 534f299..27e89b8 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,17 +1,21 @@ +import re import skema +from disco.util.cache import cached_property from disco.util.types import PreHookType +from disco.types.base import BaseType from disco.types.user import User +from disco.types.guild import Role -class MessageEmbed(skema.Model): +class MessageEmbed(BaseType): title = skema.StringType() type = skema.StringType() description = skema.StringType() url = skema.StringType() -class MessageAttachment(skema.Model): +class MessageAttachment(BaseType): id = skema.SnowflakeType() filename = skema.StringType() url = skema.StringType() @@ -21,7 +25,7 @@ class MessageAttachment(skema.Model): width = skema.IntType() -class Message(skema.Model): +class Message(BaseType): id = skema.SnowflakeType() channel_id = skema.SnowflakeType() @@ -42,3 +46,47 @@ class Message(skema.Model): embeds = skema.ListType(skema.ModelType(MessageEmbed)) attachment = skema.ListType(skema.ModelType(MessageAttachment)) + + @cached_property + def guild(self): + return self.channel.guild + + @cached_property + def channel(self): + print self.client.state.channels + return self.client.state.channels.get(self.channel_id) + + @cached_property + def mention_users(self): + return [i.id for i in self.mentions] + + @cached_property + def mention_users_dict(self): + return {i.id: i for i in self.mentions} + + def is_mentioned(self, entity): + if isinstance(entity, User): + return entity.id in self.mention_users + elif isinstance(entity, Role): + return entity.id in self.mention_roles + else: + raise Exception('Unknown entity: {}'.format(entity)) + + @cached_property + def without_mentions(self): + return self.replace_mentions( + lambda u: '', + lambda r: '') + + def replace_mentions(self, user_replace, role_replace): + if not self.mentions and not self.mention_roles: + return + + def replace(match): + id = match.group(0) + if id in self.mention_roles: + return role_replace(id) + else: + return user_replace(self.mention_users_dict.get(id)) + + return re.sub('<@!?([0-9]+)>', replace, self.content) diff --git a/disco/types/user.py b/disco/types/user.py index f1936f5..6b24c9e 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -1,7 +1,9 @@ import skema +from disco.types.base import BaseType -class User(skema.Model): + +class User(BaseType): id = skema.SnowflakeType() username = skema.StringType() diff --git a/disco/types/voice.py b/disco/types/voice.py index 91b1ea5..997a77d 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -1,5 +1,7 @@ import skema +from disco.types.base import BaseType -class VoiceState(skema.Model): + +class VoiceState(BaseType): id = skema.SnowflakeType() diff --git a/disco/util/__init__.py b/disco/util/__init__.py index e69de29..bf2ff3c 100644 --- a/disco/util/__init__.py +++ b/disco/util/__init__.py @@ -0,0 +1,18 @@ + + +def recursive_find_matching(base, match_clause): + result = [] + + if hasattr(base, '__dict__'): + values = base.__dict__.values() + else: + values = list(base) + + for v in values: + if match_clause(v): + result.append(v) + + if hasattr(v, '__dict__') or hasattr(v, '__iter__'): + result += recursive_find_matching(v, match_clause) + + return result diff --git a/disco/util/cache.py b/disco/util/cache.py new file mode 100644 index 0000000..3ca33f2 --- /dev/null +++ b/disco/util/cache.py @@ -0,0 +1,7 @@ + + +def cached_property(f): + def deco(self, *args, **kwargs): + self.__dict__[f.__name__] = f(self, *args, **kwargs) + return self.__dict__[f.__name__] + return property(deco) diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py new file mode 100644 index 0000000..6a05979 --- /dev/null +++ b/examples/basic_plugin.py @@ -0,0 +1,18 @@ +from disco.cli import disco_main +from disco.bot import Bot +from disco.bot.plugin import Plugin + + +class BasicPlugin(Plugin): + @Plugin.listen('MessageCreate') + def on_message_create(self, event): + print 'Message Created: {}'.format(event.message.content) + + @Plugin.command('test') + def on_test_command(self, event): + print 'wtf' + +if __name__ == '__main__': + bot = Bot(disco_main()) + bot.add_plugin(BasicPlugin) + bot.run_forever()