From 0335db6375f790cfc9ffe787933d2b408189f773 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 23 Sep 2016 02:40:59 -0500 Subject: [PATCH] Better pre/post hooking --- .gitignore | 3 ++ disco/bot/bot.py | 45 ++++++++++++++++----- disco/bot/command.py | 6 ++- disco/bot/plugin.py | 90 +++++++++++++++++++++++++++++++---------- disco/gateway/client.py | 23 ++++++----- disco/gateway/events.py | 7 +++- disco/state.py | 6 +-- disco/util/cache.py | 5 ++- requirements.txt | 4 +- setup.py | 34 ++++++++++++++++ 10 files changed, 172 insertions(+), 51 deletions(-) create mode 100644 .gitignore create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f123e2d --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +build/ +dist/ +disco.egg-info/ diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 18a9646..5c996d0 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -1,7 +1,15 @@ import re +from disco.client import DiscoClient + class BotConfig(object): + # Authentication token + token = None + + # Whether to enable command parsing + commands_enabled = True + # Whether the bot must be mentioned to respond to a command command_require_mention = True @@ -19,16 +27,23 @@ class BotConfig(object): # Whether an edited message can trigger a command command_allow_edit = True + # Function that when given a plugin name, returns its configuration + plugin_config_provider = None + class Bot(object): - def __init__(self, client, config=None): - self.client = client + def __init__(self, client=None, config=None): + self.client = client or DiscoClient(config.token) self.config = config or BotConfig() self.plugins = {} - self.client.events.on('MessageCreate', self.on_message_create) - self.client.events.on('MessageUpdate', self.on_message_update) + # Only bind event listeners if we're going to parse commands + if self.config.commands_enabled: + self.client.events.on('MessageCreate', self.on_message_create) + + if self.config.command_allow_edit: + self.client.events.on('MessageUpdate', self.on_message_update) # Stores the last message for every single channel self.last_message_cache = {} @@ -49,7 +64,7 @@ class Bot(object): else: self.command_matches_re = None - def handle_message(self, msg): + def get_commands_for_message(self, msg): content = msg.content if self.config.command_require_mention: @@ -61,20 +76,28 @@ class Bot(object): )))) if not match: - return False + raise StopIteration content = msg.without_mentions.strip() if self.config.command_prefix and not content.startswith(self.config.command_prefix): - return False + raise StopIteration + else: + content = content[len(self.config.command_prefix):] if not self.command_matches_re or not self.command_matches_re.match(content): - return False + raise StopIteration for command in self.commands: match = command.compiled_regex.match(content) if match: - command.execute(msg, match) + yield (command, match) + + def handle_message(self, msg): + commands = list(self.get_commands_for_message(msg)) + + if len(commands): + return any((command.execute(msg, match) for command, match in commands)) return False @@ -99,7 +122,9 @@ class Bot(object): if cls.__name__ in self.plugins: raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) - self.plugins[cls.__name__] = cls(self) + config = self.config.plugin_config_provider(cls.__name__) if self.config.plugin_config_provider else {} + + self.plugins[cls.__name__] = cls(self, config) self.plugins[cls.__name__].load() self.compute_command_matches_re() diff --git a/disco/bot/command.py b/disco/bot/command.py index e3248de..37fed7c 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -13,7 +13,8 @@ class CommandEvent(object): class Command(object): - def __init__(self, func, trigger, aliases=None, group=None, is_regex=False): + def __init__(self, plugin, func, trigger, aliases=None, group=None, is_regex=False): + self.plugin = plugin self.func = func self.triggers = [trigger] + (aliases or []) @@ -21,7 +22,8 @@ class Command(object): self.is_regex = is_regex def execute(self, msg, match): - self.func(CommandEvent(msg, match)) + event = CommandEvent(msg, match) + return self.func(event) @cached_property def compiled_regex(self): diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 7773b9a..370eab6 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -1,46 +1,72 @@ import inspect +import functools from disco.bot.command import Command class PluginDeco(object): @staticmethod - def listen(event_name): + def add_meta_deco(meta): def deco(f): if not hasattr(f, 'meta'): f.meta = [] - f.meta.append({ - 'type': 'listener', - 'event_name': event_name, - }) + f.meta.append(meta) return f return deco - @staticmethod - def command(*args, **kwargs): - def deco(f): - if not hasattr(f, 'meta'): - f.meta = [] - - f.meta.append({ - 'type': 'command', - 'args': args, - 'kwargs': kwargs, - }) - - return f - return deco + @classmethod + def listen(cls, event_name): + return cls.add_meta_deco({ + 'type': 'listener', + 'event_name': event_name, + }) + + @classmethod + def command(cls, *args, **kwargs): + return cls.add_meta_deco({ + 'type': 'command', + 'args': args, + 'kwargs': kwargs, + }) + + @classmethod + def pre_command(cls): + return cls.add_meta_deco({ + 'type': 'pre_command', + }) + + @classmethod + def post_command(cls): + return cls.add_meta_deco({ + 'type': 'post_command', + }) + + @classmethod + def pre_listener(cls): + return cls.add_meta_deco({ + 'type': 'pre_listener', + }) + + @classmethod + def post_listener(cls): + return cls.add_meta_deco({ + 'type': 'post_listener', + }) class Plugin(PluginDeco): - def __init__(self, bot): + def __init__(self, bot, config): self.bot = bot + self.config = config self.listeners = [] self.commands = [] + self._pre = {'command': [], 'listener': []} + self._post = {'command': [], 'listener': []} + for name, member in inspect.getmembers(self, predicate=inspect.ismethod): if hasattr(member, 'meta'): for meta in member.meta: @@ -48,12 +74,34 @@ class Plugin(PluginDeco): self.register_listener(member, meta['event_name']) elif meta['type'] == 'command': self.register_command(member, *meta['args'], **meta['kwargs']) + elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): + when, typ = meta['type'].split('_', 1) + self.register_trigger(typ, when, member) + + def register_trigger(self, typ, when, func): + getattr(self, '_' + when)[typ].append(func) + + def _dispatch(self, typ, func, event): + for pre in self._pre[typ]: + event = pre(event) + + if event is None: + return False + + result = func(event) + + for post in self._post[typ]: + post(event, result) + + return True def register_listener(self, func, name): + func = functools.partial(self._dispatch, 'listener', func) self.listeners.append(self.bot.client.events.on(name, func)) def register_command(self, func, *args, **kwargs): - self.commands.append(Command(func, *args, **kwargs)) + func = functools.partial(self._dispatch, 'command', func) + self.commands.append(Command(self, func, *args, **kwargs)) def destroy(self): map(lambda k: k.remove(), self._events) diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 47eecc0..5ce6e1c 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -3,6 +3,7 @@ import gevent import json import zlib import six +import ssl from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket from disco.gateway.events import GatewayEvent @@ -37,6 +38,7 @@ class GatewayClient(LoggingClass): self.seq = 0 self.session_id = None self.reconnects = 0 + self.shutting_down = False # Cached gateway URL self._cached_gateway_url = None @@ -85,7 +87,7 @@ class GatewayClient(LoggingClass): self.session_id = ready.session_id self.reconnects = 0 - def connect(self): + def connect_and_run(self): if not self._cached_gateway_url: self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') @@ -98,6 +100,7 @@ class GatewayClient(LoggingClass): on_close=self.log_on_error('Error in on_close:', self.on_close), ) self.ws._get_close_args = websocket_get_close_args_override + self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_message(self, ws, msg): # Detect zlib and decompress @@ -130,6 +133,8 @@ class GatewayClient(LoggingClass): raise Exception('Unknown packet: {}'.format(data['op'])) def on_error(self, ws, error): + if isinstance(error, KeyboardInterrupt): + self.shutting_down = True raise Exception('WS recieved error: %s', error) def on_open(self, ws): @@ -145,6 +150,10 @@ class GatewayClient(LoggingClass): shard=[self.client.sharding['number'], self.client.sharding['total']])) def on_close(self, ws, code, reason): + if self.shutting_down: + self.log.info('WS Closed: shutting down') + return + self.reconnects += 1 self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) @@ -152,19 +161,15 @@ class GatewayClient(LoggingClass): raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) # Don't resume for these error codes - if 4000 <= code <= 4010: + if code and 4000 <= code <= 4010: self.session_id = None - self.log.info('Attempting fresh reconnect') - else: - self.log.info('Attempting resume') wait_time = self.reconnects * 5 - self.log.info('Will attempt to {} after {} seconds', 'resume' if self.session_id else 'reconnect', wait_time) + self.log.info('Will attempt to %s after %s seconds', 'resume' if self.session_id else 'reconnect', wait_time) gevent.sleep(wait_time) # Reconnect - self.connect() + self.connect_and_run() def run(self): - self.connect() - self.ws.run_forever() + self.connect_and_run() diff --git a/disco/gateway/events.py b/disco/gateway/events.py index eac3991..e15bc0f 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -152,10 +152,15 @@ class MessageDeleteBulk(GatewayEvent): class PresenceUpdate(GatewayEvent): + class Game(skema.Model): + type = skema.IntType() + name = skema.StringType() + url = skema.StringType(required=False) + user = skema.ModelType(User) guild_id = skema.SnowflakeType() roles = skema.ListType(skema.SnowflakeType()) - game = skema.StringType() + game = skema.ModelType(Game) status = skema.StringType() diff --git a/disco/state.py b/disco/state.py index cc7289e..d25e164 100644 --- a/disco/state.py +++ b/disco/state.py @@ -31,8 +31,7 @@ class State(object): self.channels[channel.id] = channel def on_guild_update(self, event): - # TODO - pass + self.guilds[event.guild.id] = event.guild def on_guild_delete(self, event): if event.guild_id in self.guilds: @@ -44,8 +43,7 @@ class State(object): self.channels[event.channel.id] = event.channel def on_channel_update(self, event): - # TODO - pass + self.channels[event.channel.id] = event.channel def on_channel_delete(self, event): if event.channel.id in self.channels: diff --git a/disco/util/cache.py b/disco/util/cache.py index 3ca33f2..5a86936 100644 --- a/disco/util/cache.py +++ b/disco/util/cache.py @@ -2,6 +2,7 @@ def cached_property(f): def deco(self, *args, **kwargs): - self.__dict__[f.__name__] = f(self, *args, **kwargs) - return self.__dict__[f.__name__] + if not hasattr(self, '__' + f.__name__): + setattr(self, '__' + f.__name__, f(self, *args, **kwargs)) + return getattr(self, '__' + f.__name__) return property(deco) diff --git a/requirements.txt b/requirements.txt index d17a9cd..99856b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ enum34==1.1.6 Flask==0.11.1 gevent==1.1.2 greenlet==0.4.10 -holster==0.0.7 +# holster==0.0.7 idna==2.1 inflection==0.3.1 ipaddress==1.0.17 @@ -19,7 +19,7 @@ pycparser==2.14 pyOpenSSL==16.1.0 requests==2.11.1 six==1.10.0 -skema==0.0.1 +# skema==0.0.1 websocket-client==0.37.0 Werkzeug==0.11.11 wheel==0.24.0 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..bef96b4 --- /dev/null +++ b/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup, find_packages + +from disco import VERSION + +with open('requirements.txt') as f: + requirements = f.readlines() + +with open('README.md') as f: + readme = f.read() + +setup( + name='disco', + author='b1nzy', + url='https://github.com/b1naryth1ef/disco', + version=VERSION, + packages=find_packages(), + license='MIT', + description='A Python library for Discord', + long_description=readme, + include_package_data=True, + install_requires=requirements, + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: MIT License', + 'Intended Audience :: Developers', + 'Natural Language :: English', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Topic :: Internet', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: Utilities', + ])