From e3140b6e8b7080818df9375c7664bae18a3da710 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 7 Oct 2016 21:51:51 -0500 Subject: [PATCH] More API stuff, user/role mention arguments --- disco/api/client.py | 16 +++++------ disco/api/http.py | 2 +- disco/bot/__init__.py | 4 +-- disco/bot/bot.py | 43 +++++++++++++++++++++--------- disco/bot/command.py | 36 +++++++++++++++++++++++-- disco/bot/parser.py | 16 +++++------ disco/bot/plugin.py | 30 +++++++++++++++++++++ disco/gateway/events.py | 5 ++-- disco/state.py | 4 --- disco/types/base.py | 32 ++++++++++++++++++---- disco/types/channel.py | 38 +++++++++++++++++++++++++++ disco/types/guild.py | 57 +++++++++++++++++++++++++++++++--------- disco/types/user.py | 8 ++++++ examples/basic_plugin.py | 13 +++++++++ requirements.txt | 2 +- 15 files changed, 249 insertions(+), 57 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index 761344a..1e2a18b 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -127,11 +127,11 @@ class APIClient(LoggingClass): def guilds_channels_list(self, guild): r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) - return Channel.create_map(self.client, r.json()) + return Channel.create_map(self.client, r.json(), guild_id=guild) def guilds_channels_create(self, guild, **kwargs): r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs) - return Channel.create(self.client, r.json()) + return Channel.create(self.client, r.json(), guild_id=guild) def guilds_channels_modify(self, guild, channel, position): self.http(Routes.GUILDS_CHANNELS_MODIFY, dict(guild=guild), json={ @@ -141,11 +141,11 @@ class APIClient(LoggingClass): def guilds_members_list(self, guild): r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild)) - return GuildMember.create_map(self.client, r.json()) + return GuildMember.create_map(self.client, r.json(), guild_id=guild) def guilds_members_get(self, guild, member): r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member)) - return GuildMember.create(self.client, r.json()) + return GuildMember.create(self.client, r.json(), guild_id=guild) def guilds_members_modify(self, guild, member, **kwargs): self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) @@ -167,19 +167,19 @@ class APIClient(LoggingClass): def guilds_roles_list(self, guild): r = self.http(Routes.GUILDS_ROLES_LIST, dict(guild=guild)) - return Role.create_map(self.client, r.json()) + return Role.create_map(self.client, r.json(), guild_id=guild) def guilds_roles_create(self, guild): r = self.http(Routes.GUILDS_ROLES_CREATE, dict(guild=guild)) - return Role.create(self.client, r.json()) + return Role.create(self.client, r.json(), guild_id=guild) def guilds_roles_modify_batch(self, guild, roles): r = self.http(Routes.GUILDS_ROLES_MODIFY_BATCH, dict(guild=guild), json=roles) - return Role.create_map(self.client, r.json()) + return Role.create_map(self.client, r.json(), guild_id=guild) def guilds_roles_modify(self, guild, role, **kwargs): r = self.http(Routes.GUILDS_ROLES_MODIFY, dict(guild=guild, role=role), json=kwargs) - return Role.create(self.client, r.json()) + return Role.create(self.client, r.json(), guild_id=guild) def guilds_roles_delete(self, guild, role): self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role)) diff --git a/disco/api/http.py b/disco/api/http.py index f1baf95..6b06444 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -63,7 +63,7 @@ class Routes(object): GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}') GUILDS_BANS_DELETE = (HTTPMethod.DELETE, GUILDS + '/bans/{user}') GUILDS_ROLES_LIST = (HTTPMethod.GET, GUILDS + '/roles') - GUILDS_ROLES_CREATE = (HTTPMethod.GET, GUILDS + '/roles') + GUILDS_ROLES_CREATE = (HTTPMethod.POST, GUILDS + '/roles') GUILDS_ROLES_MODIFY_BATCH = (HTTPMethod.PATCH, GUILDS + '/roles') GUILDS_ROLES_MODIFY = (HTTPMethod.PATCH, GUILDS + '/roles/{role}') GUILDS_ROLES_DELETE = (HTTPMethod.DELETE, GUILDS + '/roles/{role}') diff --git a/disco/bot/__init__.py b/disco/bot/__init__.py index 027ba35..1ea73fa 100644 --- a/disco/bot/__init__.py +++ b/disco/bot/__init__.py @@ -1,4 +1,4 @@ -from disco.bot.bot import Bot +from disco.bot.bot import Bot, BotConfig from disco.bot.plugin import Plugin -__all__ = ['Bot', 'Plugin'] +__all__ = ['Bot', 'BotConfig', 'Plugin'] diff --git a/disco/bot/bot.py b/disco/bot/bot.py index a28263e..c9b3c1a 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -3,9 +3,11 @@ import importlib import inspect from six.moves import reload_module +from holster.threadlocal import ThreadLocal from disco.bot.plugin import Plugin from disco.bot.command import CommandEvent +# from disco.bot.storage import Storage class BotConfig(object): @@ -39,7 +41,7 @@ class BotConfig(object): helpful for allowing edits to typod commands. plugin_config_provider : Optional[function] If set, this function will be called before loading a plugin, with the - plugins name. Its expected to return a type of configuration object the + plugins class. Its expected to return a type of configuration object the plugin understands. """ token = None @@ -84,6 +86,12 @@ class Bot(object): self.client = client self.config = config or BotConfig() + # The context carries information about events in a threadlocal storage + self.ctx = ThreadLocal() + + # The storage object acts as a dynamic contextual aware store + # self.storage = Storage(self.ctx) + if self.client.config.manhole_enable: self.client.manhole_locals['bot'] = self @@ -160,17 +168,28 @@ class Bot(object): content = msg.content if self.config.commands_require_mention: - match = any(( - self.config.commands_mention_rules['user'] and msg.is_mentioned(self.client.state.me), - self.config.commands_mention_rules['everyone'] and msg.mention_everyone, - self.config.commands_mention_rules['role'] and any(map(msg.is_mentioned, - msg.guild.get_member(self.client.state.me).roles - )))) - - if not match: + mention_direct = msg.is_mentioned(self.client.state.me) + mention_everyone = msg.mention_everyone + mention_roles = list(filter(lambda r: msg.is_mentioned(r), + msg.guild.get_member(self.client.state.me).roles)) + + if not any(( + self.config.commands_mention_rules['user'] and mention_direct, + self.config.commands_mention_rules['everyone'] and mention_everyone, + self.config.commands_mention_rules['role'] and any(mention_roles), + )): raise StopIteration - content = msg.without_mentions.strip() + if mention_direct: + content = content.replace(self.client.state.me.mention, '', 1) + content = content.replace(self.client.state.me.mention_nick, '', 1) + elif mention_everyone: + content = content.replace('@everyone', '', 1) + else: + for role in mention_roles: + content = content.replace(role.mention, '', 1) + + content = content.lstrip() if self.config.commands_prefix and not content.startswith(self.config.commands_prefix): raise StopIteration @@ -247,7 +266,7 @@ class Bot(object): raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) if not config and callable(self.config.plugin_config_provider): - config = self.config.plugin_config_provider(cls.__name__) + config = self.config.plugin_config_provider(cls) self.plugins[cls.__name__] = cls(self, config) self.plugins[cls.__name__].load() @@ -293,7 +312,7 @@ class Bot(object): mod = importlib.import_module(path) for entry in map(lambda i: getattr(mod, i), dir(mod)): - if inspect.isclass(entry) and issubclass(entry, Plugin): + if inspect.isclass(entry) and issubclass(entry, Plugin) and not entry == Plugin: self.add_plugin(entry, config) break else: diff --git a/disco/bot/command.py b/disco/bot/command.py index d3c541d..4fbd9e1 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -5,6 +5,7 @@ from disco.util.functional import cached_property REGEX_FMT = '({})' ARGS_REGEX = '( (.*)$|$)' +MENTION_RE = re.compile('<@!?([0-9]+)>') class CommandEvent(object): @@ -97,10 +98,41 @@ class Command(object): self.func = func self.triggers = [trigger] + (aliases or []) - self.args = ArgumentSet.from_string(args or '') + def resolve_role(ctx, id): + return ctx.msg.guild.roles.get(id) + + def resolve_user(ctx, id): + return ctx.msg.mentions.get(id) + + self.args = ArgumentSet.from_string(args or '', { + 'mention': self.mention_type([resolve_role, resolve_user]), + 'user': self.mention_type([resolve_user], force=True), + 'role': self.mention_type([resolve_role], force=True), + }) + self.group = group self.is_regex = is_regex + @staticmethod + def mention_type(getters, force=False): + def _f(ctx, i): + res = MENTION_RE.match(i) + if not res: + raise TypeError('Invalid mention: {}'.format(i)) + + id = int(res.group(1)) + + for getter in getters: + obj = getter(ctx, id) + if obj: + return obj + + if force: + raise TypeError('Cannot resolve mention: {}'.format(id)) + + return id + return _f + @cached_property def compiled_regex(self): """ @@ -137,7 +169,7 @@ class Command(object): )) try: - args = self.args.parse(event.args) + args = self.args.parse(event.args, ctx=event) except ArgumentError as e: raise CommandError(e.message) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 513ac52..0ae2bb2 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -7,10 +7,10 @@ PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])') # Mapping of types TYPE_MAP = { - 'str': str, - 'int': int, - 'float': float, - 'snowflake': int, + 'str': lambda ctx, data: str(data), + 'int': lambda ctx, data: int(data), + 'float': lambda ctx, data: int(data), + 'snowflake': lambda ctx, data: int(data), } @@ -105,7 +105,7 @@ class ArgumentSet(object): return args - def convert(self, types, value): + def convert(self, ctx, types, value): """ Attempts to convert a value to one or more types. @@ -122,7 +122,7 @@ class ArgumentSet(object): raise Exception('Unknown type {}'.format(typ_name)) try: - return typ(value) + return typ(ctx, value) except Exception as e: continue @@ -140,7 +140,7 @@ class ArgumentSet(object): self.args.append(arg) - def parse(self, rawargs): + def parse(self, rawargs, ctx=None): """ Parse a string of raw arguments into this argument specification. """ @@ -158,7 +158,7 @@ class ArgumentSet(object): if arg.types: for idx, r in enumerate(raw): try: - raw[idx] = self.convert(arg.types, r) + raw[idx] = self.convert(ctx, arg.types, r) except: raise ArgumentError('cannot convert `{}` to `{}`'.format( r, ', '.join(arg.types) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index a9714f7..2362bbb 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -1,6 +1,7 @@ import inspect import functools import gevent +import os from holster.emitter import Priority @@ -122,6 +123,8 @@ class Plugin(LoggingClass, PluginDeco): self.bot = bot self.client = bot.client self.state = bot.client.state + self.ctx = bot.ctx + self.storage = bot.storage self.config = config def bind_all(self): @@ -149,11 +152,17 @@ class Plugin(LoggingClass, PluginDeco): """ Executes a CommandEvent this plugin owns """ + self.ctx['guild'] = event.guild + self.ctx['channel'] = event.channel + self.ctx['user'] = event.author + try: return event.command.execute(event) except CommandError as e: event.msg.reply(e.message) return False + finally: + self.ctx.drop() def register_trigger(self, typ, when, func): """ @@ -245,3 +254,24 @@ class Plugin(LoggingClass, PluginDeco): def reload(self): self.bot.reload_plugin(self.__class__) + + @staticmethod + def load_config_from_path(cls, path, format='json'): + inst = cls() + + if not os.path.exists(path): + return inst + + with open(path, 'r') as f: + data = f.read() + + if format == 'json': + import json + inst.__dict__.update(json.loads(data)) + elif format == 'yaml': + import yaml + inst.__dict__.update(yaml.load(data)) + else: + raise Exception('Unsupported config format {}'.format(format)) + + return inst diff --git a/disco/gateway/events.py b/disco/gateway/events.py index a58f62d..9e718eb 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -20,7 +20,7 @@ class GatewayEvent(Model): """ cls = globals().get(inflection.camelize(data['t'].lower())) if not cls: - raise Exception('Could not find cls for {}'.format(data['t'])) + raise Exception('Could not find cls for {} ({})'.format(data['t'], data)) return cls.create(data['d'], client) @@ -46,7 +46,8 @@ class GatewayEvent(Model): modname, _ = self._wraps_model if hasattr(self, modname) and hasattr(getattr(self, modname), name): return getattr(getattr(self, modname), name) - return object.__getattr__(self, name) + print self.__dict__ + raise AttributeError(name) def wraps_model(model, alias=None): diff --git a/disco/state.py b/disco/state.py index 1f7d998..08ed556 100644 --- a/disco/state.py +++ b/disco/state.py @@ -154,10 +154,6 @@ class State(object): self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) - for channel in event.guild.channels.values(): - channel.guild_id = event.guild.id - channel.guild = event.guild - for member in event.guild.members.values(): self.users[member.user.id] = member.user diff --git a/disco/types/base.py b/disco/types/base.py index 53f7fe0..0f8899d 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -139,8 +139,15 @@ class ModelMeta(type): class Model(six.with_metaclass(ModelMeta)): - def __init__(self, obj, client=None): - self.client = client + def __init__(self, *args, **kwargs): + self.client = kwargs.pop('client', None) + + if len(args) == 1: + obj = args[0] + elif len(args) == 2: + obj, self.client = args + else: + obj = kwargs for name, field in self._fields.items(): if name not in obj or not obj[field.src_name]: @@ -148,7 +155,7 @@ class Model(six.with_metaclass(ModelMeta)): setattr(self, field.dst_name, field.default()) continue - value = field.try_convert(obj[field.src_name], client) + value = field.try_convert(obj[field.src_name], self.client) setattr(self, field.dst_name, value) def update(self, other): @@ -165,10 +172,25 @@ class Model(six.with_metaclass(ModelMeta)): except: pass + def to_dict(self): + return {k: getattr(self, k) for k in six.iterkeys(self._fields)} + @classmethod - def create(cls, client, data): - return cls(data, client) + def create(cls, client, data, **kwargs): + inst = cls(data, client) + inst.__dict__.update(kwargs) + return inst @classmethod def create_map(cls, client, data): return list(map(functools.partial(cls.create, client), data)) + + @classmethod + def attach(cls, it, data): + for item in it: + for k, v in data.items(): + try: + setattr(item, k, v) + except: + # TODO: wtf + pass diff --git a/disco/types/channel.py b/disco/types/channel.py index 0e5be65..6bc2330 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -44,6 +44,12 @@ class PermissionOverwrite(Model): allow = Field(PermissionValue) deny = Field(PermissionValue) + def save(self): + return self.channel.update_overwrite(self) + + def delete(self): + return self.channel.delete_overwrite(self) + class Channel(Model, Permissible): """ @@ -81,6 +87,11 @@ class Channel(Model, Permissible): type = Field(enum(ChannelType)) overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') + def __init__(self, *args, **kwargs): + super(Channel, self).__init__(*args, **kwargs) + + self.attach(self.overwrites.values(), {'channel_id': self.id, 'channel': self}) + def get_permissions(self, user): """ Get the permissions a user has in the channel @@ -203,6 +214,33 @@ class Channel(Model, Permissible): vc.connect(*args, **kwargs) return vc + def create_overwrite(self, entity, allow=0, deny=0): + from disco.types.guild import Role + + type = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER + ow = PermissionOverwrite( + id=entity.id, + type=type, + allow=allow, + deny=deny + ) + + ow.channel_id = self.id + ow.channel = self + + return self.update_overwrite(ow) + + def update_overwrite(self, ow): + self.client.api.channels_permissions_modify(self.id, + ow.id, + ow.allow.value if ow.allow else 0, + ow.deny.value if ow.deny else 0, + ow.type.name) + return ow + + def delete_overwrite(self, ow): + self.client.api.channels_permissions_delete(self.id, ow.id) + class MessageIterator(object): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index 10d93eb..5d8441b 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -71,6 +71,18 @@ class Role(Model): color = Field(int) permissions = Field(PermissionValue) position = Field(int) + mentionable = Field(bool) + + def save(self): + self.guild.update_role(self) + + @property + def mention(self): + return '<@{}>'.format(self.id) + + @cached_property + def guild(self): + return self.client.state.guilds.get(self.id) class GuildMember(Model): @@ -140,6 +152,10 @@ class GuildMember(Model): """ self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') + def add_role(self, role): + roles = self.roles + [role.id] + self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles) + @cached_property def guild(self): return self.client.state.guilds.get(self.guild_id) @@ -195,7 +211,6 @@ class Guild(Model, Permissible): voice_states : dict(str, :class:`disco.types.voice.VoiceState`) All of the guilds voice states. """ - id = Field(snowflake) owner_id = Field(snowflake) afk_channel_id = Field(snowflake) @@ -215,6 +230,15 @@ class Guild(Model, Permissible): emojis = Field(dictof(Emoji, key='id')) voice_states = Field(dictof(VoiceState, key='session_id')) + def __init__(self, *args, **kwargs): + super(Guild, self).__init__(*args, **kwargs) + + self.attach(self.channels.values(), {'guild_id': self.id}) + self.attach(self.members.values(), {'guild_id': self.id}) + self.attach(self.roles.values(), {'guild_id': self.id}) + self.attach(self.emojis.values(), {'guild_id': self.id}) + self.attach(self.voice_states.values(), {'guild_id': self.id}) + def get_permissions(self, user): """ Get the permissions a user has in this guild. @@ -270,14 +294,23 @@ class Guild(Model, Permissible): return self.members.get(user) - def validate_members(self, ctx): - if self.members: - for member in self.members.values(): - member.guild = self - member.guild_id = self.id - - def validate_channels(self, ctx): - if self.channels: - for channel in self.channels.values(): - channel.guild_id = self.id - channel.guild = self + def create_role(self): + """ + Create a new role. + + Returns + ------- + :class:`Role` + The newly created role. + """ + return self.client.api.guilds_roles_create(self.id) + + def update_role(self, role): + return self.client.api.guilds_roles_modify(self.id, role.id, **{ + 'name': role.name, + 'permissions': role.permissions.value, + 'position': role.position, + 'color': role.color, + 'hoist': role.hoist, + 'mentionable': role.mentionable, + }) diff --git a/disco/types/user.py b/disco/types/user.py index 160f983..7281796 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -9,6 +9,14 @@ class User(Model): verified = Field(bool) email = Field(str) + @property + def mention(self): + return '<@{}>'.format(self.id) + + @property + def mention_nick(self): + return '<@!{}>'.format(self.id) + def to_string(self): return '{}#{}'.format(self.username, self.discriminator) diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 5d8441f..6ed449f 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -94,3 +94,16 @@ class BasicPlugin(Plugin): event.msg.reply('```json\n{}\n```'.format( json.dumps(perms.to_dict(), sort_keys=True, indent=2, separators=(',', ': ')) )) + + """ + @Plugin.command('tag', ' [value:str]') + def on_tag(self, event, name, value=None): + if value: + self.storage.guild['tags'][name] = value + event.msg.reply(':ok_hand:') + else: + if name in self.storage.guild['tags']: + return event.msg.reply(self.storage.guild['tags'][name]) + else: + event.msg.reply('Unknown tag `{}`'.format(name)) + """ diff --git a/requirements.txt b/requirements.txt index e4d2570..a8586e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ gevent==1.1.2 -holster==1.0.4 +holster==1.0.5 inflection==0.3.1 requests==2.11.1 six==1.10.0