Browse Source

Better pre/post hooking

pull/3/head
Andrei 9 years ago
parent
commit
0335db6375
  1. 3
      .gitignore
  2. 45
      disco/bot/bot.py
  3. 6
      disco/bot/command.py
  4. 90
      disco/bot/plugin.py
  5. 23
      disco/gateway/client.py
  6. 7
      disco/gateway/events.py
  7. 6
      disco/state.py
  8. 5
      disco/util/cache.py
  9. 4
      requirements.txt
  10. 34
      setup.py

3
.gitignore

@ -0,0 +1,3 @@
build/
dist/
disco.egg-info/

45
disco/bot/bot.py

@ -1,7 +1,15 @@
import re import re
from disco.client import DiscoClient
class BotConfig(object): class BotConfig(object):
# Authentication token
token = None
# Whether to enable command parsing
commands_enabled = True
# Whether the bot must be mentioned to respond to a command # Whether the bot must be mentioned to respond to a command
command_require_mention = True command_require_mention = True
@ -19,16 +27,23 @@ class BotConfig(object):
# Whether an edited message can trigger a command # Whether an edited message can trigger a command
command_allow_edit = True command_allow_edit = True
# Function that when given a plugin name, returns its configuration
plugin_config_provider = None
class Bot(object): class Bot(object):
def __init__(self, client, config=None): def __init__(self, client=None, config=None):
self.client = client self.client = client or DiscoClient(config.token)
self.config = config or BotConfig() self.config = config or BotConfig()
self.plugins = {} self.plugins = {}
self.client.events.on('MessageCreate', self.on_message_create) # Only bind event listeners if we're going to parse commands
self.client.events.on('MessageUpdate', self.on_message_update) if self.config.commands_enabled:
self.client.events.on('MessageCreate', self.on_message_create)
if self.config.command_allow_edit:
self.client.events.on('MessageUpdate', self.on_message_update)
# Stores the last message for every single channel # Stores the last message for every single channel
self.last_message_cache = {} self.last_message_cache = {}
@ -49,7 +64,7 @@ class Bot(object):
else: else:
self.command_matches_re = None self.command_matches_re = None
def handle_message(self, msg): def get_commands_for_message(self, msg):
content = msg.content content = msg.content
if self.config.command_require_mention: if self.config.command_require_mention:
@ -61,20 +76,28 @@ class Bot(object):
)))) ))))
if not match: if not match:
return False raise StopIteration
content = msg.without_mentions.strip() content = msg.without_mentions.strip()
if self.config.command_prefix and not content.startswith(self.config.command_prefix): if self.config.command_prefix and not content.startswith(self.config.command_prefix):
return False raise StopIteration
else:
content = content[len(self.config.command_prefix):]
if not self.command_matches_re or not self.command_matches_re.match(content): if not self.command_matches_re or not self.command_matches_re.match(content):
return False raise StopIteration
for command in self.commands: for command in self.commands:
match = command.compiled_regex.match(content) match = command.compiled_regex.match(content)
if match: if match:
command.execute(msg, match) yield (command, match)
def handle_message(self, msg):
commands = list(self.get_commands_for_message(msg))
if len(commands):
return any((command.execute(msg, match) for command, match in commands))
return False return False
@ -99,7 +122,9 @@ class Bot(object):
if cls.__name__ in self.plugins: if cls.__name__ in self.plugins:
raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) raise Exception('Cannot add already added plugin: {}'.format(cls.__name__))
self.plugins[cls.__name__] = cls(self) config = self.config.plugin_config_provider(cls.__name__) if self.config.plugin_config_provider else {}
self.plugins[cls.__name__] = cls(self, config)
self.plugins[cls.__name__].load() self.plugins[cls.__name__].load()
self.compute_command_matches_re() self.compute_command_matches_re()

6
disco/bot/command.py

@ -13,7 +13,8 @@ class CommandEvent(object):
class Command(object): class Command(object):
def __init__(self, func, trigger, aliases=None, group=None, is_regex=False): def __init__(self, plugin, func, trigger, aliases=None, group=None, is_regex=False):
self.plugin = plugin
self.func = func self.func = func
self.triggers = [trigger] + (aliases or []) self.triggers = [trigger] + (aliases or [])
@ -21,7 +22,8 @@ class Command(object):
self.is_regex = is_regex self.is_regex = is_regex
def execute(self, msg, match): def execute(self, msg, match):
self.func(CommandEvent(msg, match)) event = CommandEvent(msg, match)
return self.func(event)
@cached_property @cached_property
def compiled_regex(self): def compiled_regex(self):

90
disco/bot/plugin.py

@ -1,46 +1,72 @@
import inspect import inspect
import functools
from disco.bot.command import Command from disco.bot.command import Command
class PluginDeco(object): class PluginDeco(object):
@staticmethod @staticmethod
def listen(event_name): def add_meta_deco(meta):
def deco(f): def deco(f):
if not hasattr(f, 'meta'): if not hasattr(f, 'meta'):
f.meta = [] f.meta = []
f.meta.append({ f.meta.append(meta)
'type': 'listener',
'event_name': event_name,
})
return f return f
return deco return deco
@staticmethod @classmethod
def command(*args, **kwargs): def listen(cls, event_name):
def deco(f): return cls.add_meta_deco({
if not hasattr(f, 'meta'): 'type': 'listener',
f.meta = [] 'event_name': event_name,
})
f.meta.append({
'type': 'command', @classmethod
'args': args, def command(cls, *args, **kwargs):
'kwargs': kwargs, return cls.add_meta_deco({
}) 'type': 'command',
'args': args,
return f 'kwargs': kwargs,
return deco })
@classmethod
def pre_command(cls):
return cls.add_meta_deco({
'type': 'pre_command',
})
@classmethod
def post_command(cls):
return cls.add_meta_deco({
'type': 'post_command',
})
@classmethod
def pre_listener(cls):
return cls.add_meta_deco({
'type': 'pre_listener',
})
@classmethod
def post_listener(cls):
return cls.add_meta_deco({
'type': 'post_listener',
})
class Plugin(PluginDeco): class Plugin(PluginDeco):
def __init__(self, bot): def __init__(self, bot, config):
self.bot = bot self.bot = bot
self.config = config
self.listeners = [] self.listeners = []
self.commands = [] self.commands = []
self._pre = {'command': [], 'listener': []}
self._post = {'command': [], 'listener': []}
for name, member in inspect.getmembers(self, predicate=inspect.ismethod): for name, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, 'meta'): if hasattr(member, 'meta'):
for meta in member.meta: for meta in member.meta:
@ -48,12 +74,34 @@ class Plugin(PluginDeco):
self.register_listener(member, meta['event_name']) self.register_listener(member, meta['event_name'])
elif meta['type'] == 'command': elif meta['type'] == 'command':
self.register_command(member, *meta['args'], **meta['kwargs']) self.register_command(member, *meta['args'], **meta['kwargs'])
elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'):
when, typ = meta['type'].split('_', 1)
self.register_trigger(typ, when, member)
def register_trigger(self, typ, when, func):
getattr(self, '_' + when)[typ].append(func)
def _dispatch(self, typ, func, event):
for pre in self._pre[typ]:
event = pre(event)
if event is None:
return False
result = func(event)
for post in self._post[typ]:
post(event, result)
return True
def register_listener(self, func, name): def register_listener(self, func, name):
func = functools.partial(self._dispatch, 'listener', func)
self.listeners.append(self.bot.client.events.on(name, func)) self.listeners.append(self.bot.client.events.on(name, func))
def register_command(self, func, *args, **kwargs): def register_command(self, func, *args, **kwargs):
self.commands.append(Command(func, *args, **kwargs)) func = functools.partial(self._dispatch, 'command', func)
self.commands.append(Command(self, func, *args, **kwargs))
def destroy(self): def destroy(self):
map(lambda k: k.remove(), self._events) map(lambda k: k.remove(), self._events)

23
disco/gateway/client.py

@ -3,6 +3,7 @@ import gevent
import json import json
import zlib import zlib
import six import six
import ssl
from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket
from disco.gateway.events import GatewayEvent from disco.gateway.events import GatewayEvent
@ -37,6 +38,7 @@ class GatewayClient(LoggingClass):
self.seq = 0 self.seq = 0
self.session_id = None self.session_id = None
self.reconnects = 0 self.reconnects = 0
self.shutting_down = False
# Cached gateway URL # Cached gateway URL
self._cached_gateway_url = None self._cached_gateway_url = None
@ -85,7 +87,7 @@ class GatewayClient(LoggingClass):
self.session_id = ready.session_id self.session_id = ready.session_id
self.reconnects = 0 self.reconnects = 0
def connect(self): def connect_and_run(self):
if not self._cached_gateway_url: if not self._cached_gateway_url:
self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json')
@ -98,6 +100,7 @@ class GatewayClient(LoggingClass):
on_close=self.log_on_error('Error in on_close:', self.on_close), on_close=self.log_on_error('Error in on_close:', self.on_close),
) )
self.ws._get_close_args = websocket_get_close_args_override self.ws._get_close_args = websocket_get_close_args_override
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_message(self, ws, msg): def on_message(self, ws, msg):
# Detect zlib and decompress # Detect zlib and decompress
@ -130,6 +133,8 @@ class GatewayClient(LoggingClass):
raise Exception('Unknown packet: {}'.format(data['op'])) raise Exception('Unknown packet: {}'.format(data['op']))
def on_error(self, ws, error): def on_error(self, ws, error):
if isinstance(error, KeyboardInterrupt):
self.shutting_down = True
raise Exception('WS recieved error: %s', error) raise Exception('WS recieved error: %s', error)
def on_open(self, ws): def on_open(self, ws):
@ -145,6 +150,10 @@ class GatewayClient(LoggingClass):
shard=[self.client.sharding['number'], self.client.sharding['total']])) shard=[self.client.sharding['number'], self.client.sharding['total']]))
def on_close(self, ws, code, reason): def on_close(self, ws, code, reason):
if self.shutting_down:
self.log.info('WS Closed: shutting down')
return
self.reconnects += 1 self.reconnects += 1
self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects)
@ -152,19 +161,15 @@ class GatewayClient(LoggingClass):
raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS))
# Don't resume for these error codes # Don't resume for these error codes
if 4000 <= code <= 4010: if code and 4000 <= code <= 4010:
self.session_id = None self.session_id = None
self.log.info('Attempting fresh reconnect')
else:
self.log.info('Attempting resume')
wait_time = self.reconnects * 5 wait_time = self.reconnects * 5
self.log.info('Will attempt to {} after {} seconds', 'resume' if self.session_id else 'reconnect', wait_time) self.log.info('Will attempt to %s after %s seconds', 'resume' if self.session_id else 'reconnect', wait_time)
gevent.sleep(wait_time) gevent.sleep(wait_time)
# Reconnect # Reconnect
self.connect() self.connect_and_run()
def run(self): def run(self):
self.connect() self.connect_and_run()
self.ws.run_forever()

7
disco/gateway/events.py

@ -152,10 +152,15 @@ class MessageDeleteBulk(GatewayEvent):
class PresenceUpdate(GatewayEvent): class PresenceUpdate(GatewayEvent):
class Game(skema.Model):
type = skema.IntType()
name = skema.StringType()
url = skema.StringType(required=False)
user = skema.ModelType(User) user = skema.ModelType(User)
guild_id = skema.SnowflakeType() guild_id = skema.SnowflakeType()
roles = skema.ListType(skema.SnowflakeType()) roles = skema.ListType(skema.SnowflakeType())
game = skema.StringType() game = skema.ModelType(Game)
status = skema.StringType() status = skema.StringType()

6
disco/state.py

@ -31,8 +31,7 @@ class State(object):
self.channels[channel.id] = channel self.channels[channel.id] = channel
def on_guild_update(self, event): def on_guild_update(self, event):
# TODO self.guilds[event.guild.id] = event.guild
pass
def on_guild_delete(self, event): def on_guild_delete(self, event):
if event.guild_id in self.guilds: if event.guild_id in self.guilds:
@ -44,8 +43,7 @@ class State(object):
self.channels[event.channel.id] = event.channel self.channels[event.channel.id] = event.channel
def on_channel_update(self, event): def on_channel_update(self, event):
# TODO self.channels[event.channel.id] = event.channel
pass
def on_channel_delete(self, event): def on_channel_delete(self, event):
if event.channel.id in self.channels: if event.channel.id in self.channels:

5
disco/util/cache.py

@ -2,6 +2,7 @@
def cached_property(f): def cached_property(f):
def deco(self, *args, **kwargs): def deco(self, *args, **kwargs):
self.__dict__[f.__name__] = f(self, *args, **kwargs) if not hasattr(self, '__' + f.__name__):
return self.__dict__[f.__name__] setattr(self, '__' + f.__name__, f(self, *args, **kwargs))
return getattr(self, '__' + f.__name__)
return property(deco) return property(deco)

4
requirements.txt

@ -6,7 +6,7 @@ enum34==1.1.6
Flask==0.11.1 Flask==0.11.1
gevent==1.1.2 gevent==1.1.2
greenlet==0.4.10 greenlet==0.4.10
holster==0.0.7 # holster==0.0.7
idna==2.1 idna==2.1
inflection==0.3.1 inflection==0.3.1
ipaddress==1.0.17 ipaddress==1.0.17
@ -19,7 +19,7 @@ pycparser==2.14
pyOpenSSL==16.1.0 pyOpenSSL==16.1.0
requests==2.11.1 requests==2.11.1
six==1.10.0 six==1.10.0
skema==0.0.1 # skema==0.0.1
websocket-client==0.37.0 websocket-client==0.37.0
Werkzeug==0.11.11 Werkzeug==0.11.11
wheel==0.24.0 wheel==0.24.0

34
setup.py

@ -0,0 +1,34 @@
from setuptools import setup, find_packages
from disco import VERSION
with open('requirements.txt') as f:
requirements = f.readlines()
with open('README.md') as f:
readme = f.read()
setup(
name='disco',
author='b1nzy',
url='https://github.com/b1naryth1ef/disco',
version=VERSION,
packages=find_packages(),
license='MIT',
description='A Python library for Discord',
long_description=readme,
include_package_data=True,
install_requires=requirements,
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: MIT License',
'Intended Audience :: Developers',
'Natural Language :: English',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Topic :: Internet',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
'Topic :: Utilities',
])
Loading…
Cancel
Save