Browse Source

plugins, commands, etc

pull/3/head
Andrei 9 years ago
parent
commit
c4d4b40107
  1. 1
      disco/bot/__init__.py
  2. 116
      disco/bot/bot.py
  3. 29
      disco/bot/command.py
  4. 65
      disco/bot/plugin.py
  5. 8
      disco/cli.py
  6. 7
      disco/client.py
  7. 14
      disco/gateway/client.py
  8. 12
      disco/gateway/events.py
  9. 52
      disco/state.py
  10. 5
      disco/types/base.py
  11. 14
      disco/types/channel.py
  12. 22
      disco/types/guild.py
  13. 54
      disco/types/message.py
  14. 4
      disco/types/user.py
  15. 4
      disco/types/voice.py
  16. 18
      disco/util/__init__.py
  17. 7
      disco/util/cache.py
  18. 18
      examples/basic_plugin.py

1
disco/bot/__init__.py

@ -0,0 +1 @@
from disco.bot.bot import Bot

116
disco/bot/bot.py

@ -0,0 +1,116 @@
import re
class BotConfig(object):
# Whether the bot must be mentioned to respond to a command
command_require_mention = True
# Rules about what mentions trigger the bot
command_mention_rules = {
# 'here': False,
'everyone': False,
'role': True,
'user': True,
}
# The prefix required for EVERY command
command_prefix = ''
# Whether an edited message can trigger a command
command_allow_edit = True
class Bot(object):
def __init__(self, client, config=None):
self.client = client
self.config = config or BotConfig()
self.plugins = {}
self.client.events.on('MessageCreate', self.on_message_create)
self.client.events.on('MessageUpdate', self.on_message_update)
# Stores the last message for every single channel
self.last_message_cache = {}
# Stores a giant regex matcher for all commands
self.command_matches_re = None
@property
def commands(self):
for plugin in self.plugins.values():
for command in plugin.commands:
yield command
def compute_command_matches_re(self):
re_str = '|'.join(command.regex for command in self.commands)
print re_str
if re_str:
self.command_matches_re = re.compile(re_str)
else:
self.command_matches_re = None
def handle_message(self, msg):
content = msg.content
if self.config.command_require_mention:
match = any((
self.config.command_mention_rules['user'] and msg.is_mentioned(self.client.state.me),
self.config.command_mention_rules['everyone'] and msg.mention_everyone,
self.config.command_mention_rules['role'] and any(map(msg.is_mentioned,
msg.guild.get_member(self.client.state.me).roles
))))
if not match:
return False
content = msg.without_mentions.strip()
if self.config.command_prefix and not content.startswith(self.config.command_prefix):
return False
if not self.command_matches_re or not self.command_matches_re.match(content):
return False
for command in self.commands:
if command.compiled_regex.match(content):
command.execute(msg)
return False
def on_message_create(self, event):
if self.config.command_allow_edit:
self.last_message_cache[event.message.channel_id] = (event.message, False)
self.handle_message(event.message)
def on_message_update(self, event):
if self.config.command_allow_edit:
msg = self.last_message_cache.get(event.message.channel_id)
if msg and event.message.id == msg[0].id:
triggered = msg[1]
if not triggered:
triggered = self.handle_message(event.message)
self.last_message_cache[event.message.channel_id] = (event.message, triggered)
def add_plugin(self, cls):
if cls.__name__ in self.plugins:
raise Exception('Cannot add already added plugin: {}'.format(cls.__name__))
self.plugins[cls.__name__] = cls(self)
self.plugins[cls.__name__].load()
self.compute_command_matches_re()
def rmv_plugin(self, cls):
if cls.__name__ not in self.plugins:
raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__))
self.plugins[cls.__name__].unload()
self.plugins[cls.__name__].destroy()
del self.plugins[cls.__name__]
self.compute_command_matches_re()
def run_forever(self):
self.client.run_forever()

29
disco/bot/command.py

@ -0,0 +1,29 @@
import re
from disco.util.cache import cached_property
ARGS_REGEX = '( (.*)$|$)'
class Command(object):
def __init__(self, func, trigger, aliases=None, group=None, is_regex=False):
self.func = func
self.triggers = [trigger] + (aliases or [])
self.group = group
self.is_regex = is_regex
def execute(self, msg):
self.func(msg)
@cached_property
def compiled_regex(self):
return re.compile(self.regex)
@property
def regex(self):
if self.is_regex:
return '|'.join(self.triggers)
else:
group = self.group + ' ' if self.group else ''
return '|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX

65
disco/bot/plugin.py

@ -0,0 +1,65 @@
import inspect
from disco.bot.command import Command
class PluginDeco(object):
@staticmethod
def listen(event_name):
def deco(f):
if not hasattr(f, 'meta'):
f.meta = []
f.meta.append({
'type': 'listener',
'event_name': event_name,
})
return f
return deco
@staticmethod
def command(*args, **kwargs):
def deco(f):
if not hasattr(f, 'meta'):
f.meta = []
f.meta.append({
'type': 'command',
'args': args,
'kwargs': kwargs,
})
return f
return deco
class Plugin(PluginDeco):
def __init__(self, bot):
self.bot = bot
self.listeners = []
self.commands = []
for name, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, 'meta'):
for meta in member.meta:
if meta['type'] == 'listener':
self.register_listener(member, meta['event_name'])
elif meta['type'] == 'command':
self.register_command(member, *meta['args'], **meta['kwargs'])
def register_listener(self, func, name):
self.listeners.append(self.bot.client.events.on(name, func))
def register_command(self, func, *args, **kwargs):
self.commands.append(Command(func, *args, **kwargs))
def destroy(self):
map(lambda k: k.remove(), self._events)
def load(self):
pass
def unload(self):
pass

8
disco/cli.py

@ -3,14 +3,14 @@ import argparse
from gevent import monkey from gevent import monkey
monkey.patch_all()
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--token', help='Bot Authentication Token', required=True) parser.add_argument('--token', help='Bot Authentication Token', required=True)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
def main(): def disco_main():
monkey.patch_all()
args = parser.parse_args() args = parser.parse_args()
from disco.util.token import is_valid_token from disco.util.token import is_valid_token
@ -20,7 +20,7 @@ def main():
return return
from disco.client import DiscoClient from disco.client import DiscoClient
DiscoClient(args.token).run_forever() return DiscoClient(args.token)
if __name__ == '__main__': if __name__ == '__main__':
main() disco_main().run_forever()

7
disco/client.py

@ -1,5 +1,9 @@
import logging import logging
import gevent
from holster.emitter import Emitter
from disco.state import State
from disco.api.client import APIClient from disco.api.client import APIClient
from disco.gateway.client import GatewayClient from disco.gateway.client import GatewayClient
@ -12,6 +16,9 @@ class DiscoClient(object):
self.token = token self.token = token
self.sharding = sharding or {'number': 0, 'total': 1} self.sharding = sharding or {'number': 0, 'total': 1}
self.events = Emitter(gevent.spawn)
self.state = State(self)
self.api = APIClient(self) self.api = APIClient(self)
self.gw = GatewayClient(self) self.gw = GatewayClient(self)

14
disco/gateway/client.py

@ -3,11 +3,8 @@ import gevent
import json import json
import zlib import zlib
from holster.emitter import Emitter
# from holster.util import SimpleObject
from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket
from disco.gateway.events import GatewayEvent, Ready from disco.gateway.events import GatewayEvent
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
GATEWAY_VERSION = 6 GATEWAY_VERSION = 6
@ -28,9 +25,8 @@ class GatewayClient(LoggingClass):
def __init__(self, client): def __init__(self, client):
super(GatewayClient, self).__init__() super(GatewayClient, self).__init__()
self.client = client self.client = client
self.emitter = Emitter(gevent.spawn)
self.emitter.on(Ready, self.on_ready) self.client.events.on('Ready', self.on_ready)
# Websocket connection # Websocket connection
self.ws = None self.ws = None
@ -60,9 +56,9 @@ class GatewayClient(LoggingClass):
gevent.sleep(interval / 1000) gevent.sleep(interval / 1000)
def handle_dispatch(self, packet): def handle_dispatch(self, packet):
cls, obj = GatewayEvent.from_dispatch(packet) obj = GatewayEvent.from_dispatch(self.client, packet)
self.log.info('Dispatching %s', cls) self.log.info('Dispatching %s', obj.__class__.__name__)
self.emitter.emit(cls, obj) self.client.events.emit(obj.__class__.__name__, obj)
def handle_heartbeat(self, packet): def handle_heartbeat(self, packet):
pass pass

12
disco/gateway/events.py

@ -1,17 +1,25 @@
import inflection import inflection
import skema import skema
from disco.util import recursive_find_matching
from disco.types.base import BaseType
from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState
class GatewayEvent(skema.Model): class GatewayEvent(skema.Model):
@staticmethod @staticmethod
def from_dispatch(obj): def from_dispatch(client, obj):
cls = globals().get(inflection.camelize(obj['t'].lower())) cls = globals().get(inflection.camelize(obj['t'].lower()))
if not cls: if not cls:
raise Exception('Could not find cls for {}'.format(obj['t'])) raise Exception('Could not find cls for {}'.format(obj['t']))
return cls, cls.create(obj['d']) obj = cls.create(obj['d'])
# TODO: use skema info
for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)):
item.client = client
return obj
@classmethod @classmethod
def create(cls, obj): def create(cls, obj):

52
disco/state.py

@ -0,0 +1,52 @@
class State(object):
def __init__(self, client):
self.client = client
self.me = None
self.channels = {}
self.guilds = {}
self.client.events.on('Ready', self.on_ready)
# Guilds
self.client.events.on('GuildCreate', self.on_guild_create)
self.client.events.on('GuildUpdate', self.on_guild_update)
self.client.events.on('GuildDelete', self.on_guild_delete)
# Channels
self.client.events.on('ChannelCreate', self.on_channel_create)
self.client.events.on('ChannelUpdate', self.on_channel_update)
self.client.events.on('ChannelDelete', self.on_channel_delete)
def on_ready(self, event):
self.me = event.user
def on_guild_create(self, event):
self.guilds[event.guild.id] = event.guild
for channel in event.guild.channels:
self.channels[channel.id] = channel
def on_guild_update(self, event):
# TODO
pass
def on_guild_delete(self, event):
if event.guild_id in self.guilds:
del self.guilds[event.guild_id]
# CHANNELS?
def on_channel_create(self, event):
self.channels[event.channel.id] = event.channel
def on_channel_update(self, event):
# TODO
pass
def on_channel_delete(self, event):
if event.channel.id in self.channels:
del self.channels[event.channel.id]

5
disco/types/base.py

@ -0,0 +1,5 @@
import skema
class BaseType(skema.Model):
pass

14
disco/types/channel.py

@ -2,7 +2,8 @@ import skema
from holster.enum import Enum from holster.enum import Enum
# from disco.types.guild import Guild from disco.util.cache import cached_property
from disco.types.base import BaseType
from disco.types.user import User from disco.types.user import User
@ -19,7 +20,7 @@ PermissionOverwriteType = Enum(
) )
class PermissionOverwrite(skema.Model): class PermissionOverwrite(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
type = skema.StringType(choices=PermissionOverwriteType.ALL_VALUES) type = skema.StringType(choices=PermissionOverwriteType.ALL_VALUES)
@ -27,8 +28,9 @@ class PermissionOverwrite(skema.Model):
deny = skema.IntType() deny = skema.IntType()
class Channel(skema.Model): class Channel(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
guild_id = skema.SnowflakeType(required=False)
name = skema.StringType() name = skema.StringType()
topic = skema.StringType() topic = skema.StringType()
@ -40,3 +42,9 @@ class Channel(skema.Model):
type = skema.IntType(choices=ChannelType.ALL_VALUES) type = skema.IntType(choices=ChannelType.ALL_VALUES)
permission_overwrites = skema.ListType(skema.ModelType(PermissionOverwrite)) permission_overwrites = skema.ListType(skema.ModelType(PermissionOverwrite))
@cached_property
def guild(self):
print self.guild_id
print self.client.state.guilds
return self.client.state.guilds.get(self.guild_id)

22
disco/types/guild.py

@ -1,12 +1,14 @@
import skema import skema
from disco.util.cache import cached_property
from disco.types.base import BaseType
from disco.util.types import PreHookType from disco.util.types import PreHookType
from disco.types.user import User from disco.types.user import User
from disco.types.voice import VoiceState from disco.types.voice import VoiceState
from disco.types.channel import Channel from disco.types.channel import Channel
class Emoji(skema.Model): class Emoji(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
name = skema.StringType() name = skema.StringType()
require_colons = skema.BooleanType() require_colons = skema.BooleanType()
@ -14,7 +16,7 @@ class Emoji(skema.Model):
roles = skema.ListType(skema.SnowflakeType()) roles = skema.ListType(skema.SnowflakeType())
class Role(skema.Model): class Role(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
name = skema.StringType() name = skema.StringType()
hoist = skema.BooleanType() hoist = skema.BooleanType()
@ -24,7 +26,7 @@ class Role(skema.Model):
position = skema.IntType() position = skema.IntType()
class GuildMember(skema.Model): class GuildMember(BaseType):
user = skema.ModelType(User) user = skema.ModelType(User)
mute = skema.BooleanType() mute = skema.BooleanType()
deaf = skema.BooleanType() deaf = skema.BooleanType()
@ -32,7 +34,7 @@ class GuildMember(skema.Model):
roles = skema.ListType(skema.SnowflakeType()) roles = skema.ListType(skema.SnowflakeType())
class Guild(skema.Model): class Guild(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
owner_id = skema.SnowflakeType() owner_id = skema.SnowflakeType()
@ -56,3 +58,15 @@ class Guild(skema.Model):
channels = skema.ListType(skema.ModelType(Channel)) channels = skema.ListType(skema.ModelType(Channel))
roles = skema.ListType(skema.ModelType(Role)) roles = skema.ListType(skema.ModelType(Role))
emojis = skema.ListType(skema.ModelType(Emoji)) emojis = skema.ListType(skema.ModelType(Emoji))
@cached_property
def members_dict(self):
return {i.user.id: i for i in self.members}
def get_member(self, user):
return self.members_dict.get(user.id)
def validate_channels(self, ctx):
if self.channels:
for channel in self.channels:
channel.guild_id = self.id

54
disco/types/message.py

@ -1,17 +1,21 @@
import re
import skema import skema
from disco.util.cache import cached_property
from disco.util.types import PreHookType from disco.util.types import PreHookType
from disco.types.base import BaseType
from disco.types.user import User from disco.types.user import User
from disco.types.guild import Role
class MessageEmbed(skema.Model): class MessageEmbed(BaseType):
title = skema.StringType() title = skema.StringType()
type = skema.StringType() type = skema.StringType()
description = skema.StringType() description = skema.StringType()
url = skema.StringType() url = skema.StringType()
class MessageAttachment(skema.Model): class MessageAttachment(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
filename = skema.StringType() filename = skema.StringType()
url = skema.StringType() url = skema.StringType()
@ -21,7 +25,7 @@ class MessageAttachment(skema.Model):
width = skema.IntType() width = skema.IntType()
class Message(skema.Model): class Message(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
channel_id = skema.SnowflakeType() channel_id = skema.SnowflakeType()
@ -42,3 +46,47 @@ class Message(skema.Model):
embeds = skema.ListType(skema.ModelType(MessageEmbed)) embeds = skema.ListType(skema.ModelType(MessageEmbed))
attachment = skema.ListType(skema.ModelType(MessageAttachment)) attachment = skema.ListType(skema.ModelType(MessageAttachment))
@cached_property
def guild(self):
return self.channel.guild
@cached_property
def channel(self):
print self.client.state.channels
return self.client.state.channels.get(self.channel_id)
@cached_property
def mention_users(self):
return [i.id for i in self.mentions]
@cached_property
def mention_users_dict(self):
return {i.id: i for i in self.mentions}
def is_mentioned(self, entity):
if isinstance(entity, User):
return entity.id in self.mention_users
elif isinstance(entity, Role):
return entity.id in self.mention_roles
else:
raise Exception('Unknown entity: {}'.format(entity))
@cached_property
def without_mentions(self):
return self.replace_mentions(
lambda u: '',
lambda r: '')
def replace_mentions(self, user_replace, role_replace):
if not self.mentions and not self.mention_roles:
return
def replace(match):
id = match.group(0)
if id in self.mention_roles:
return role_replace(id)
else:
return user_replace(self.mention_users_dict.get(id))
return re.sub('<@!?([0-9]+)>', replace, self.content)

4
disco/types/user.py

@ -1,7 +1,9 @@
import skema import skema
from disco.types.base import BaseType
class User(skema.Model):
class User(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
username = skema.StringType() username = skema.StringType()

4
disco/types/voice.py

@ -1,5 +1,7 @@
import skema import skema
from disco.types.base import BaseType
class VoiceState(skema.Model):
class VoiceState(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()

18
disco/util/__init__.py

@ -0,0 +1,18 @@
def recursive_find_matching(base, match_clause):
result = []
if hasattr(base, '__dict__'):
values = base.__dict__.values()
else:
values = list(base)
for v in values:
if match_clause(v):
result.append(v)
if hasattr(v, '__dict__') or hasattr(v, '__iter__'):
result += recursive_find_matching(v, match_clause)
return result

7
disco/util/cache.py

@ -0,0 +1,7 @@
def cached_property(f):
def deco(self, *args, **kwargs):
self.__dict__[f.__name__] = f(self, *args, **kwargs)
return self.__dict__[f.__name__]
return property(deco)

18
examples/basic_plugin.py

@ -0,0 +1,18 @@
from disco.cli import disco_main
from disco.bot import Bot
from disco.bot.plugin import Plugin
class BasicPlugin(Plugin):
@Plugin.listen('MessageCreate')
def on_message_create(self, event):
print 'Message Created: {}'.format(event.message.content)
@Plugin.command('test')
def on_test_command(self, event):
print 'wtf'
if __name__ == '__main__':
bot = Bot(disco_main())
bot.add_plugin(BasicPlugin)
bot.run_forever()
Loading…
Cancel
Save