diff --git a/disco/api/http.py b/disco/api/http.py index 28cb892..df699c3 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -124,7 +124,6 @@ class HTTPClient(LoggingClass): # Make the actual request url = self.BASE_URL + route[1].format(**args) - print route[0].value, url, kwargs r = requests.request(route[0].value, url, **kwargs) # Update rate limiter diff --git a/disco/bot/bot.py b/disco/bot/bot.py index bce1fe1..e802f35 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -55,7 +55,7 @@ class Bot(object): @property def commands(self): for plugin in self.plugins.values(): - for command in plugin.commands: + for command in plugin.commands.values(): yield command def compute_command_matches_re(self): diff --git a/disco/bot/parser.py b/disco/bot/parser.py index b90fdfe..2394ce4 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -74,7 +74,7 @@ class ArgumentSet(object): parsed = [] for index, arg in enumerate(self.args): - if not arg.required and index + arg.true_count <= len(rawargs): + if not arg.required and index + arg.true_count > len(rawargs): continue if arg.count == 0: @@ -94,7 +94,7 @@ class ArgumentSet(object): if arg.count == 1: raw = raw[0] - if not arg.types or arg.types == ['str'] and isinstance(raw, list): + if (not arg.types or arg.types == ['str']) and isinstance(raw, list): raw = ' '.join(raw) parsed.append(raw) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index bac5145..cbd5c2c 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -66,7 +66,7 @@ class Plugin(LoggingClass, PluginDeco): self.config = config self.listeners = [] - self.commands = [] + self.commands = {} self._pre = {'command': [], 'listener': []} self._post = {'command': [], 'listener': []} @@ -111,8 +111,8 @@ class Plugin(LoggingClass, PluginDeco): self.listeners.append(self.bot.client.events.on(name, func)) def register_command(self, func, *args, **kwargs): - func = functools.partial(self._dispatch, 'command', func) - self.commands.append(Command(self, func, *args, **kwargs)) + wrapped = functools.partial(self._dispatch, 'command', func) + self.commands[func.__name__] = Command(self, func, *args, **kwargs) def destroy(self): map(lambda k: k.remove(), self._events) diff --git a/disco/cli.py b/disco/cli.py index 5c05fb9..17975f2 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -1,3 +1,5 @@ +from __future__ import print_function + import logging import argparse @@ -16,7 +18,7 @@ def disco_main(): from disco.util.token import is_valid_token if not is_valid_token(args.token): - print 'Invalid token passed' + print('Invalid token passed') return from disco.client import DiscoClient diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 90b167d..cca0a1d 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -106,7 +106,7 @@ class GatewayClient(LoggingClass): def on_message(self, ws, msg): # Detect zlib and decompress if msg[0] != '{': - msg = zlib.decompress(msg, 15, TEN_MEGABYTES) + msg = zlib.decompress(msg, 15, TEN_MEGABYTES).decode("utf-8") try: data = loads(msg) diff --git a/disco/state.py b/disco/state.py index 439ab85..532afd4 100644 --- a/disco/state.py +++ b/disco/state.py @@ -38,6 +38,9 @@ class State(object): self.client.events.on('GuildUpdate', self.on_guild_update) self.client.events.on('GuildDelete', self.on_guild_delete) + # TODO: guild members + # TODO: guild roles + # Channels self.client.events.on('ChannelCreate', self.on_channel_create) self.client.events.on('ChannelUpdate', self.on_channel_update) @@ -80,6 +83,9 @@ class State(object): self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) + for member in event.guild.members.values(): + self.users[member.user.id] = member.user + def on_guild_update(self, event): self.guilds[event.guild.id].update(event.guild) diff --git a/disco/types/base.py b/disco/types/base.py index 28efe7f..3a8a6f6 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -5,9 +5,6 @@ from disco.util import skema_find_recursive_by_type class BaseType(skema.Model): - def on_create(self): - pass - def update(self, other): for name, field in other.__class__._fields.items(): value = getattr(other, name) @@ -25,7 +22,6 @@ class BaseType(skema.Model): item.client = client obj.client = client - obj.on_create() return obj @classmethod diff --git a/disco/types/message.py b/disco/types/message.py index fdd4ece..fd3645f 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,11 +1,11 @@ import re import skema +from disco.util import to_snowflake from disco.util.cache import cached_property from disco.util.types import PreHookType, ListToDictType from disco.types.base import BaseType from disco.types.user import User -from disco.types.guild import Role class MessageEmbed(BaseType): @@ -65,14 +65,8 @@ class Message(BaseType): return self.client.api.channels_messages_delete(self.channel_id, self.id) def is_mentioned(self, entity): - if isinstance(entity, User): - return entity.id in self.mentions - elif isinstance(entity, Role): - return entity.id in self.mention_roles - elif isinstance(entity, long): - return entity in self.mentions or entity in self.mention_roles - else: - raise Exception('Unknown entity: {} ({})'.format(entity, type(entity))) + id = to_snowflake(entity) + return id in self.mentions or id in self.mention_roles @cached_property def without_mentions(self): diff --git a/disco/util/__init__.py b/disco/util/__init__.py index d0c3a37..d745c08 100644 --- a/disco/util/__init__.py +++ b/disco/util/__init__.py @@ -1,11 +1,12 @@ +import six import skema def to_snowflake(i): - if isinstance(i, long): + if isinstance(i, six.integer_types): return i elif isinstance(i, str): - return long(i) + return int(i) elif hasattr(i, 'id'): return i.id diff --git a/disco/util/json.py b/disco/util/json.py index 01267ff..5d92963 100644 --- a/disco/util/json.py +++ b/disco/util/json.py @@ -1,11 +1,11 @@ -from __future__ import absolute_import +from __future__ import absolute_import, print_function from json import dumps try: from rapidjson import loads except ImportError: - print '[WARNING] rapidjson not installed, falling back to default Python JSON parser' + print('[WARNING] rapidjson not installed, falling back to default Python JSON parser') from json import loads __all__ = ['dumps', 'loads'] diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 0708a57..36134e6 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -1,5 +1,7 @@ import gevent +import sys +from disco import VERSION from disco.cli import disco_main from disco.bot import Bot from disco.bot.plugin import Plugin @@ -12,6 +14,24 @@ class BasicPlugin(Plugin): event.message.author.username, event.message.content)) + @Plugin.command('status', '[component]') + def on_status_command(self, event, component=None): + if component == 'state': + parts = [] + parts.append('Guilds: {}'.format(len(self.state.guilds))) + parts.append('Channels: {}'.format(len(self.state.channels))) + parts.append('Users: {}'.format(len(self.state.users))) + + event.msg.reply('State Information: ```\n{}\n```'.format('\n'.join(parts))) + return + + event.msg.reply('Disco v{} running on Python {}.{}.{}'.format( + VERSION, + sys.version_info.major, + sys.version_info.minor, + sys.version_info.micro, + )) + @Plugin.command('echo', '') def on_test_command(self, event, content): event.msg.reply(content) @@ -38,7 +58,6 @@ class BasicPlugin(Plugin): pin_count = len(event.channel.get_pins()) msg_count = 0 - print event.channel.messages_iter(bulk=True) for msgs in event.channel.messages_iter(bulk=True): msg_count += len(msgs) diff --git a/requirements.txt b/requirements.txt index 99856b4..7c84120 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ enum34==1.1.6 Flask==0.11.1 gevent==1.1.2 greenlet==0.4.10 -# holster==0.0.7 +holster==1.0.0 idna==2.1 inflection==0.3.1 ipaddress==1.0.17