Browse Source

plugins, commands, etc

pull/3/head
Andrei 9 years ago
parent
commit
c4d4b40107
  1. 1
      disco/bot/__init__.py
  2. 116
      disco/bot/bot.py
  3. 29
      disco/bot/command.py
  4. 65
      disco/bot/plugin.py
  5. 8
      disco/cli.py
  6. 7
      disco/client.py
  7. 14
      disco/gateway/client.py
  8. 12
      disco/gateway/events.py
  9. 52
      disco/state.py
  10. 5
      disco/types/base.py
  11. 14
      disco/types/channel.py
  12. 22
      disco/types/guild.py
  13. 54
      disco/types/message.py
  14. 4
      disco/types/user.py
  15. 4
      disco/types/voice.py
  16. 18
      disco/util/__init__.py
  17. 7
      disco/util/cache.py
  18. 18
      examples/basic_plugin.py

1
disco/bot/__init__.py

@ -0,0 +1 @@
from disco.bot.bot import Bot

116
disco/bot/bot.py

@ -0,0 +1,116 @@
import re
class BotConfig(object):
# Whether the bot must be mentioned to respond to a command
command_require_mention = True
# Rules about what mentions trigger the bot
command_mention_rules = {
# 'here': False,
'everyone': False,
'role': True,
'user': True,
}
# The prefix required for EVERY command
command_prefix = ''
# Whether an edited message can trigger a command
command_allow_edit = True
class Bot(object):
def __init__(self, client, config=None):
self.client = client
self.config = config or BotConfig()
self.plugins = {}
self.client.events.on('MessageCreate', self.on_message_create)
self.client.events.on('MessageUpdate', self.on_message_update)
# Stores the last message for every single channel
self.last_message_cache = {}
# Stores a giant regex matcher for all commands
self.command_matches_re = None
@property
def commands(self):
for plugin in self.plugins.values():
for command in plugin.commands:
yield command
def compute_command_matches_re(self):
re_str = '|'.join(command.regex for command in self.commands)
print re_str
if re_str:
self.command_matches_re = re.compile(re_str)
else:
self.command_matches_re = None
def handle_message(self, msg):
content = msg.content
if self.config.command_require_mention:
match = any((
self.config.command_mention_rules['user'] and msg.is_mentioned(self.client.state.me),
self.config.command_mention_rules['everyone'] and msg.mention_everyone,
self.config.command_mention_rules['role'] and any(map(msg.is_mentioned,
msg.guild.get_member(self.client.state.me).roles
))))
if not match:
return False
content = msg.without_mentions.strip()
if self.config.command_prefix and not content.startswith(self.config.command_prefix):
return False
if not self.command_matches_re or not self.command_matches_re.match(content):
return False
for command in self.commands:
if command.compiled_regex.match(content):
command.execute(msg)
return False
def on_message_create(self, event):
if self.config.command_allow_edit:
self.last_message_cache[event.message.channel_id] = (event.message, False)
self.handle_message(event.message)
def on_message_update(self, event):
if self.config.command_allow_edit:
msg = self.last_message_cache.get(event.message.channel_id)
if msg and event.message.id == msg[0].id:
triggered = msg[1]
if not triggered:
triggered = self.handle_message(event.message)
self.last_message_cache[event.message.channel_id] = (event.message, triggered)
def add_plugin(self, cls):
if cls.__name__ in self.plugins:
raise Exception('Cannot add already added plugin: {}'.format(cls.__name__))
self.plugins[cls.__name__] = cls(self)
self.plugins[cls.__name__].load()
self.compute_command_matches_re()
def rmv_plugin(self, cls):
if cls.__name__ not in self.plugins:
raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__))
self.plugins[cls.__name__].unload()
self.plugins[cls.__name__].destroy()
del self.plugins[cls.__name__]
self.compute_command_matches_re()
def run_forever(self):
self.client.run_forever()

29
disco/bot/command.py

@ -0,0 +1,29 @@
import re
from disco.util.cache import cached_property
ARGS_REGEX = '( (.*)$|$)'
class Command(object):
def __init__(self, func, trigger, aliases=None, group=None, is_regex=False):
self.func = func
self.triggers = [trigger] + (aliases or [])
self.group = group
self.is_regex = is_regex
def execute(self, msg):
self.func(msg)
@cached_property
def compiled_regex(self):
return re.compile(self.regex)
@property
def regex(self):
if self.is_regex:
return '|'.join(self.triggers)
else:
group = self.group + ' ' if self.group else ''
return '|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX

65
disco/bot/plugin.py

@ -0,0 +1,65 @@
import inspect
from disco.bot.command import Command
class PluginDeco(object):
@staticmethod
def listen(event_name):
def deco(f):
if not hasattr(f, 'meta'):
f.meta = []
f.meta.append({
'type': 'listener',
'event_name': event_name,
})
return f
return deco
@staticmethod
def command(*args, **kwargs):
def deco(f):
if not hasattr(f, 'meta'):
f.meta = []
f.meta.append({
'type': 'command',
'args': args,
'kwargs': kwargs,
})
return f
return deco
class Plugin(PluginDeco):
def __init__(self, bot):
self.bot = bot
self.listeners = []
self.commands = []
for name, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, 'meta'):
for meta in member.meta:
if meta['type'] == 'listener':
self.register_listener(member, meta['event_name'])
elif meta['type'] == 'command':
self.register_command(member, *meta['args'], **meta['kwargs'])
def register_listener(self, func, name):
self.listeners.append(self.bot.client.events.on(name, func))
def register_command(self, func, *args, **kwargs):
self.commands.append(Command(func, *args, **kwargs))
def destroy(self):
map(lambda k: k.remove(), self._events)
def load(self):
pass
def unload(self):
pass

8
disco/cli.py

@ -3,14 +3,14 @@ import argparse
from gevent import monkey
monkey.patch_all()
parser = argparse.ArgumentParser()
parser.add_argument('--token', help='Bot Authentication Token', required=True)
logging.basicConfig(level=logging.INFO)
def main():
monkey.patch_all()
def disco_main():
args = parser.parse_args()
from disco.util.token import is_valid_token
@ -20,7 +20,7 @@ def main():
return
from disco.client import DiscoClient
DiscoClient(args.token).run_forever()
return DiscoClient(args.token)
if __name__ == '__main__':
main()
disco_main().run_forever()

7
disco/client.py

@ -1,5 +1,9 @@
import logging
import gevent
from holster.emitter import Emitter
from disco.state import State
from disco.api.client import APIClient
from disco.gateway.client import GatewayClient
@ -12,6 +16,9 @@ class DiscoClient(object):
self.token = token
self.sharding = sharding or {'number': 0, 'total': 1}
self.events = Emitter(gevent.spawn)
self.state = State(self)
self.api = APIClient(self)
self.gw = GatewayClient(self)

14
disco/gateway/client.py

@ -3,11 +3,8 @@ import gevent
import json
import zlib
from holster.emitter import Emitter
# from holster.util import SimpleObject
from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket
from disco.gateway.events import GatewayEvent, Ready
from disco.gateway.events import GatewayEvent
from disco.util.logging import LoggingClass
GATEWAY_VERSION = 6
@ -28,9 +25,8 @@ class GatewayClient(LoggingClass):
def __init__(self, client):
super(GatewayClient, self).__init__()
self.client = client
self.emitter = Emitter(gevent.spawn)
self.emitter.on(Ready, self.on_ready)
self.client.events.on('Ready', self.on_ready)
# Websocket connection
self.ws = None
@ -60,9 +56,9 @@ class GatewayClient(LoggingClass):
gevent.sleep(interval / 1000)
def handle_dispatch(self, packet):
cls, obj = GatewayEvent.from_dispatch(packet)
self.log.info('Dispatching %s', cls)
self.emitter.emit(cls, obj)
obj = GatewayEvent.from_dispatch(self.client, packet)
self.log.info('Dispatching %s', obj.__class__.__name__)
self.client.events.emit(obj.__class__.__name__, obj)
def handle_heartbeat(self, packet):
pass

12
disco/gateway/events.py

@ -1,17 +1,25 @@
import inflection
import skema
from disco.util import recursive_find_matching
from disco.types.base import BaseType
from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState
class GatewayEvent(skema.Model):
@staticmethod
def from_dispatch(obj):
def from_dispatch(client, obj):
cls = globals().get(inflection.camelize(obj['t'].lower()))
if not cls:
raise Exception('Could not find cls for {}'.format(obj['t']))
return cls, cls.create(obj['d'])
obj = cls.create(obj['d'])
# TODO: use skema info
for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)):
item.client = client
return obj
@classmethod
def create(cls, obj):

52
disco/state.py

@ -0,0 +1,52 @@
class State(object):
def __init__(self, client):
self.client = client
self.me = None
self.channels = {}
self.guilds = {}
self.client.events.on('Ready', self.on_ready)
# Guilds
self.client.events.on('GuildCreate', self.on_guild_create)
self.client.events.on('GuildUpdate', self.on_guild_update)
self.client.events.on('GuildDelete', self.on_guild_delete)
# Channels
self.client.events.on('ChannelCreate', self.on_channel_create)
self.client.events.on('ChannelUpdate', self.on_channel_update)
self.client.events.on('ChannelDelete', self.on_channel_delete)
def on_ready(self, event):
self.me = event.user
def on_guild_create(self, event):
self.guilds[event.guild.id] = event.guild
for channel in event.guild.channels:
self.channels[channel.id] = channel
def on_guild_update(self, event):
# TODO
pass
def on_guild_delete(self, event):
if event.guild_id in self.guilds:
del self.guilds[event.guild_id]
# CHANNELS?
def on_channel_create(self, event):
self.channels[event.channel.id] = event.channel
def on_channel_update(self, event):
# TODO
pass
def on_channel_delete(self, event):
if event.channel.id in self.channels:
del self.channels[event.channel.id]

5
disco/types/base.py

@ -0,0 +1,5 @@
import skema
class BaseType(skema.Model):
pass

14
disco/types/channel.py

@ -2,7 +2,8 @@ import skema
from holster.enum import Enum
# from disco.types.guild import Guild
from disco.util.cache import cached_property
from disco.types.base import BaseType
from disco.types.user import User
@ -19,7 +20,7 @@ PermissionOverwriteType = Enum(
)
class PermissionOverwrite(skema.Model):
class PermissionOverwrite(BaseType):
id = skema.SnowflakeType()
type = skema.StringType(choices=PermissionOverwriteType.ALL_VALUES)
@ -27,8 +28,9 @@ class PermissionOverwrite(skema.Model):
deny = skema.IntType()
class Channel(skema.Model):
class Channel(BaseType):
id = skema.SnowflakeType()
guild_id = skema.SnowflakeType(required=False)
name = skema.StringType()
topic = skema.StringType()
@ -40,3 +42,9 @@ class Channel(skema.Model):
type = skema.IntType(choices=ChannelType.ALL_VALUES)
permission_overwrites = skema.ListType(skema.ModelType(PermissionOverwrite))
@cached_property
def guild(self):
print self.guild_id
print self.client.state.guilds
return self.client.state.guilds.get(self.guild_id)

22
disco/types/guild.py

@ -1,12 +1,14 @@
import skema
from disco.util.cache import cached_property
from disco.types.base import BaseType
from disco.util.types import PreHookType
from disco.types.user import User
from disco.types.voice import VoiceState
from disco.types.channel import Channel
class Emoji(skema.Model):
class Emoji(BaseType):
id = skema.SnowflakeType()
name = skema.StringType()
require_colons = skema.BooleanType()
@ -14,7 +16,7 @@ class Emoji(skema.Model):
roles = skema.ListType(skema.SnowflakeType())
class Role(skema.Model):
class Role(BaseType):
id = skema.SnowflakeType()
name = skema.StringType()
hoist = skema.BooleanType()
@ -24,7 +26,7 @@ class Role(skema.Model):
position = skema.IntType()
class GuildMember(skema.Model):
class GuildMember(BaseType):
user = skema.ModelType(User)
mute = skema.BooleanType()
deaf = skema.BooleanType()
@ -32,7 +34,7 @@ class GuildMember(skema.Model):
roles = skema.ListType(skema.SnowflakeType())
class Guild(skema.Model):
class Guild(BaseType):
id = skema.SnowflakeType()
owner_id = skema.SnowflakeType()
@ -56,3 +58,15 @@ class Guild(skema.Model):
channels = skema.ListType(skema.ModelType(Channel))
roles = skema.ListType(skema.ModelType(Role))
emojis = skema.ListType(skema.ModelType(Emoji))
@cached_property
def members_dict(self):
return {i.user.id: i for i in self.members}
def get_member(self, user):
return self.members_dict.get(user.id)
def validate_channels(self, ctx):
if self.channels:
for channel in self.channels:
channel.guild_id = self.id

54
disco/types/message.py

@ -1,17 +1,21 @@
import re
import skema
from disco.util.cache import cached_property
from disco.util.types import PreHookType
from disco.types.base import BaseType
from disco.types.user import User
from disco.types.guild import Role
class MessageEmbed(skema.Model):
class MessageEmbed(BaseType):
title = skema.StringType()
type = skema.StringType()
description = skema.StringType()
url = skema.StringType()
class MessageAttachment(skema.Model):
class MessageAttachment(BaseType):
id = skema.SnowflakeType()
filename = skema.StringType()
url = skema.StringType()
@ -21,7 +25,7 @@ class MessageAttachment(skema.Model):
width = skema.IntType()
class Message(skema.Model):
class Message(BaseType):
id = skema.SnowflakeType()
channel_id = skema.SnowflakeType()
@ -42,3 +46,47 @@ class Message(skema.Model):
embeds = skema.ListType(skema.ModelType(MessageEmbed))
attachment = skema.ListType(skema.ModelType(MessageAttachment))
@cached_property
def guild(self):
return self.channel.guild
@cached_property
def channel(self):
print self.client.state.channels
return self.client.state.channels.get(self.channel_id)
@cached_property
def mention_users(self):
return [i.id for i in self.mentions]
@cached_property
def mention_users_dict(self):
return {i.id: i for i in self.mentions}
def is_mentioned(self, entity):
if isinstance(entity, User):
return entity.id in self.mention_users
elif isinstance(entity, Role):
return entity.id in self.mention_roles
else:
raise Exception('Unknown entity: {}'.format(entity))
@cached_property
def without_mentions(self):
return self.replace_mentions(
lambda u: '',
lambda r: '')
def replace_mentions(self, user_replace, role_replace):
if not self.mentions and not self.mention_roles:
return
def replace(match):
id = match.group(0)
if id in self.mention_roles:
return role_replace(id)
else:
return user_replace(self.mention_users_dict.get(id))
return re.sub('<@!?([0-9]+)>', replace, self.content)

4
disco/types/user.py

@ -1,7 +1,9 @@
import skema
from disco.types.base import BaseType
class User(skema.Model):
class User(BaseType):
id = skema.SnowflakeType()
username = skema.StringType()

4
disco/types/voice.py

@ -1,5 +1,7 @@
import skema
from disco.types.base import BaseType
class VoiceState(skema.Model):
class VoiceState(BaseType):
id = skema.SnowflakeType()

18
disco/util/__init__.py

@ -0,0 +1,18 @@
def recursive_find_matching(base, match_clause):
result = []
if hasattr(base, '__dict__'):
values = base.__dict__.values()
else:
values = list(base)
for v in values:
if match_clause(v):
result.append(v)
if hasattr(v, '__dict__') or hasattr(v, '__iter__'):
result += recursive_find_matching(v, match_clause)
return result

7
disco/util/cache.py

@ -0,0 +1,7 @@
def cached_property(f):
def deco(self, *args, **kwargs):
self.__dict__[f.__name__] = f(self, *args, **kwargs)
return self.__dict__[f.__name__]
return property(deco)

18
examples/basic_plugin.py

@ -0,0 +1,18 @@
from disco.cli import disco_main
from disco.bot import Bot
from disco.bot.plugin import Plugin
class BasicPlugin(Plugin):
@Plugin.listen('MessageCreate')
def on_message_create(self, event):
print 'Message Created: {}'.format(event.message.content)
@Plugin.command('test')
def on_test_command(self, event):
print 'wtf'
if __name__ == '__main__':
bot = Bot(disco_main())
bot.add_plugin(BasicPlugin)
bot.run_forever()
Loading…
Cancel
Save