diff --git a/disco/api/client.py b/disco/api/client.py index 1e2a18b..a6cea02 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -1,3 +1,5 @@ +import six + from disco.api.http import Routes, HTTPClient from disco.util.logging import LoggingClass @@ -15,7 +17,7 @@ def optional(**kwargs): :returns: dict """ - return {k: v for k, v in kwargs.items() if v is not None} + return {k: v for k, v in six.iteritems(kwargs) if v is not None} class APIClient(LoggingClass): diff --git a/disco/api/http.py b/disco/api/http.py index 6b06444..99ca763 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -1,6 +1,7 @@ import requests import random import gevent +import six from holster.enum import Enum @@ -166,7 +167,7 @@ class HTTPClient(LoggingClass): kwargs['headers'] = self.headers # Build the bucket URL - filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in args.items()} + filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in six.iteritems(args)} bucket = (route[0].value, route[1].format(**filtered)) # Possibly wait if we're rate limited diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 4c6846f..215bdb6 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -1,7 +1,8 @@ import re import os -import importlib +import six import inspect +import importlib from six.moves import reload_module from holster.threadlocal import ThreadLocal @@ -48,6 +49,10 @@ class BotConfig(Config): commands_level_getter : function If set, a function which when given a GuildMember or User, returns the relevant :class:`disco.bot.commands.CommandLevels`. + commands_group_abbrev : function + If true, command groups may be abbreviated to the least common variation. + E.g. the grouping 'test' may be abbreviated down to 't', unless 'tag' exists, + in which case it may be abbreviated down to 'te'. plugin_config_provider : Optional[function] If set, this function will replace the default configuration loading function, which normally attempts to load a file located at config/plugin_name.fmt @@ -72,6 +77,7 @@ class BotConfig(Config): commands_prefix = '' commands_allow_edit = True commands_level_getter = None + commands_group_abbrev = True plugin_config_provider = None plugin_config_format = 'yaml' @@ -121,6 +127,7 @@ class Bot(object): self.client.manhole_locals['bot'] = self self.plugins = {} + self.group_abbrev = {} # Only bind event listeners if we're going to parse commands if self.config.commands_enabled: @@ -140,7 +147,7 @@ class Bot(object): self.add_plugin_module(plugin_mod) # Convert level mapping - for k, v in self.config.levels.items(): + for k, v in six.iteritems(self.config.levels): self.config.levels[k] = CommandLevels.get(v) @classmethod @@ -169,10 +176,37 @@ class Bot(object): """ Generator of all commands this bots plugins have defined """ - for plugin in self.plugins.values(): - for command in plugin.commands.values(): + for plugin in six.itervalues(self.plugins): + for command in six.itervalues(plugin.commands): yield command + def recompute(self): + """ + Called when a plugin is loaded/unloaded to recompute internal state. + """ + self.compute_group_abbrev() + if self.config.commands_group_abbrev: + self.compute_command_matches_re() + + def compute_group_abbrev(self): + """ + Computes all possible abbreviations for a command grouping + """ + self.group_abbrev = {} + groups = set(command.group for command in self.commands if command.group) + + for group in groups: + grp = group + while grp: + # If the group already exists, means someone else thought they + # could use it so we need to + if grp in list(six.itervalues(self.group_abbrev)): + self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp} + else: + self.group_abbrev[group] = grp + + grp = grp[:-1] + def compute_command_matches_re(self): """ Computes a single regex which matches all possible command combinations. @@ -340,7 +374,7 @@ class Bot(object): self.plugins[cls.__name__] = cls(self, config) self.plugins[cls.__name__].load() - self.compute_command_matches_re() + self.recompute() def rmv_plugin(self, cls): """ @@ -356,7 +390,7 @@ class Bot(object): self.plugins[cls.__name__].unload() del self.plugins[cls.__name__] - self.compute_command_matches_re() + self.recompute() def reload_plugin(self, cls): """ diff --git a/disco/bot/command.py b/disco/bot/command.py index 53dd672..505d74f 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -103,12 +103,14 @@ class Command(object): is_regex : Optional[bool] Whether the triggers for this command should be treated as raw regex. """ - def __init__(self, plugin, func, trigger, args=None, level=None, - aliases=None, group=None, is_regex=False): - + def __init__(self, plugin, func, trigger, *args, **kwargs): self.plugin = plugin self.func = func - self.triggers = [trigger] + (aliases or []) + self.triggers = [trigger] + self.update(*args, **kwargs) + + def update(self, args=None, level=None, aliases=None, group=None, is_regex=None): + self.triggers += aliases or [] def resolve_role(ctx, id): return ctx.msg.guild.roles.get(id) @@ -161,7 +163,12 @@ class Command(object): if self.is_regex: return REGEX_FMT.format('|'.join(self.triggers)) else: - group = self.group + ' ' if self.group else '' + group = '' + if self.group: + if self.group in self.plugin.bot.group_abbrev.get(self.group): + group = '{}(?:\w+)? '.format(self.group) + else: + group = self.group return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX) def execute(self, event): diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index dd6c097..a8585f7 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -1,7 +1,8 @@ -import inspect -import functools +import six import gevent +import inspect import weakref +import functools from holster.emitter import Priority @@ -170,6 +171,7 @@ class Plugin(LoggingClass, PluginDeco): if meta['type'] == 'listener': self.register_listener(member, meta['what'], meta['desc'], meta['priority']) elif meta['type'] == 'command': + meta['kwargs']['update'] = True self.register_command(member, *meta['args'], **meta['kwargs']) elif meta['type'] == 'schedule': self.register_schedule(member, *meta['args'], **meta['kwargs']) @@ -260,8 +262,11 @@ class Plugin(LoggingClass, PluginDeco): Keyword arguments to pass onto the :class:`disco.bot.command.Command` object. """ - wrapped = functools.partial(self._dispatch, 'command', func) - self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs) + if kwargs.pop('update', False) and func.__name__ in self.commands: + self.commands[func.__name__].update(*args, **kwargs) + else: + wrapped = functools.partial(self._dispatch, 'command', func) + self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs) def register_schedule(self, func, interval, repeat=True, init=True): """ @@ -303,7 +308,7 @@ class Plugin(LoggingClass, PluginDeco): for listener in self.listeners: listener.remove() - for schedule in self.schedules.values(): + for schedule in six.itervalues(self.schedules): schedule.kill() def reload(self): diff --git a/disco/cli.py b/disco/cli.py index bb2e357..951fd96 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -5,6 +5,7 @@ creating and running bots/clients. from __future__ import print_function import os +import six import logging import argparse @@ -47,7 +48,7 @@ def disco_main(run=False): else: config = ClientConfig() - for k, v in vars(args).items(): + for k, v in six.iteritems(vars(args)): if hasattr(config, k) and v is not None: setattr(config, k, v) diff --git a/disco/state.py b/disco/state.py index cec1687..cdf5d81 100644 --- a/disco/state.py +++ b/disco/state.py @@ -1,3 +1,4 @@ +import six import inflection from collections import defaultdict, deque, namedtuple @@ -159,7 +160,7 @@ class State(object): self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) - for member in event.guild.members.values(): + for member in six.itervalues(event.guild.members): self.users[member.user.id] = member.user # Request full member list diff --git a/disco/types/base.py b/disco/types/base.py index 83bcb3d..6a0ca05 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -178,7 +178,7 @@ class Model(six.with_metaclass(ModelMeta)): else: obj = kwargs - for name, field in self._fields.items(): + for name, field in six.iteritems(self._fields): if field.src_name not in obj or not obj[field.src_name]: if field.has_default(): setattr(self, field.dst_name, field.default()) @@ -217,7 +217,7 @@ class Model(six.with_metaclass(ModelMeta)): @classmethod def attach(cls, it, data): for item in it: - for k, v in data.items(): + for k, v in six.iteritems(data): try: setattr(item, k, v) except: diff --git a/disco/types/channel.py b/disco/types/channel.py index 7636047..0c12048 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,12 +1,12 @@ -from holster.enum import Enum +import six -from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text -from disco.types.permissions import PermissionValue +from holster.enum import Enum from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property, one_or_many, chunks from disco.types.user import User -from disco.types.permissions import Permissions, Permissible +from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text +from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.voice.client import VoiceClient @@ -91,7 +91,7 @@ class Channel(Model, Permissible): def __init__(self, *args, **kwargs): super(Channel, self).__init__(*args, **kwargs) - self.attach(self.overwrites.values(), {'channel_id': self.id, 'channel': self}) + self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) def get_permissions(self, user): """ @@ -108,7 +108,7 @@ class Channel(Model, Permissible): member = self.guild.members.get(user.id) base = self.guild.get_permissions(user) - for ow in self.overwrites.values(): + for ow in six.itervalues(self.overwrites): if ow.id != user.id and ow.id not in member.roles: continue diff --git a/disco/types/guild.py b/disco/types/guild.py index 70ff577..ea65f1d 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -1,3 +1,5 @@ +import six + from holster.enum import Enum from disco.api.http import APIException @@ -6,8 +8,8 @@ from disco.util.functional import cached_property from disco.types.base import Model, Field, snowflake, listof, dictof, datetime, text, binary, enum from disco.types.user import User from disco.types.voice import VoiceState -from disco.types.permissions import PermissionValue, Permissions, Permissible from disco.types.channel import Channel +from disco.types.permissions import PermissionValue, Permissions, Permissible VerificationLevel = Enum( @@ -243,11 +245,11 @@ class Guild(Model, Permissible): def __init__(self, *args, **kwargs): super(Guild, self).__init__(*args, **kwargs) - self.attach(self.channels.values(), {'guild_id': self.id}) - self.attach(self.members.values(), {'guild_id': self.id}) - self.attach(self.roles.values(), {'guild_id': self.id}) - self.attach(self.emojis.values(), {'guild_id': self.id}) - self.attach(self.voice_states.values(), {'guild_id': self.id}) + self.attach(six.itervalues(self.channels), {'guild_id': self.id}) + self.attach(six.itervalues(self.members), {'guild_id': self.id}) + self.attach(six.itervalues(self.roles), {'guild_id': self.id}) + self.attach(six.itervalues(self.emojis), {'guild_id': self.id}) + self.attach(six.itervalues(self.voice_states), {'guild_id': self.id}) def get_permissions(self, user): """ @@ -281,7 +283,7 @@ class Guild(Model, Permissible): """ user = to_snowflake(user) - for state in self.voice_states.values(): + for state in six.itervalues(self.voice_states): if state.user_id == user: return state diff --git a/disco/util/websocket.py b/disco/util/websocket.py index c2bcf31..3871b13 100644 --- a/disco/util/websocket.py +++ b/disco/util/websocket.py @@ -24,7 +24,7 @@ class Websocket(LoggingClass, websocket.WebSocketApp): self.emitter = Emitter(gevent.spawn) # Hack to get events to emit - for var in self.__dict__.keys(): + for var in six.iterkeys(self.__dict__): if not var.startswith('on_'): continue