From 37d7b3bdefe9e0dd39cd976873b7af0fb4478a83 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 23 Sep 2016 21:50:18 -0500 Subject: [PATCH] Lots of API stuff, state additions, skema/modeling cleanup --- disco/api/client.py | 132 +++++++++++++++++++++++++++++++---- disco/api/http.py | 146 ++++++++++++++++++++++----------------- disco/bot/command.py | 12 ++++ disco/bot/parser.py | 7 +- disco/bot/plugin.py | 2 + disco/gateway/events.py | 24 ++++--- disco/state.py | 78 ++++++++++++++++++--- disco/types/base.py | 10 ++- disco/types/channel.py | 87 ++++++++++++++++++++++- disco/types/guild.py | 25 ++++--- disco/types/invite.py | 22 ++++++ disco/types/message.py | 22 +++--- disco/util/__init__.py | 36 +++++++--- disco/util/types.py | 20 +++++- examples/basic_plugin.py | 30 ++++++++ 15 files changed, 513 insertions(+), 140 deletions(-) create mode 100644 disco/types/invite.py diff --git a/disco/api/client.py b/disco/api/client.py index 77e3eb7..eb6527a 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -1,12 +1,15 @@ from disco.api.http import Routes, HTTPClient from disco.util.logging import LoggingClass +from disco.types.user import User from disco.types.message import Message +from disco.types.guild import Guild, GuildMember, Role from disco.types.channel import Channel +from disco.types.invite import Invite def optional(**kwargs): - return {k: v for k, v in kwargs if v is not None} + return {k: v for k, v in kwargs.items() if v is not None} class APIClient(LoggingClass): @@ -21,34 +24,33 @@ class APIClient(LoggingClass): return data['url'] + '?v={}&encoding={}'.format(version, encoding) def channels_get(self, channel): - r = self.http(Routes.CHANNELS_GET, channel) + r = self.http(Routes.CHANNELS_GET, dict(channel=channel)) return Channel.create(self.client, r.json()) def channels_modify(self, channel, **kwargs): - r = self.http(Routes.CHANNELS_MODIFY, channel, json=kwargs) + r = self.http(Routes.CHANNELS_MODIFY, dict(channel=channel), json=kwargs) return Channel.create(self.client, r.json()) def channels_delete(self, channel): - r = self.http(Routes.CHANNELS_DELETE, channel) + r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel)) return Channel.create(self.client, r.json()) def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50): - r = self.http(Routes.CHANNELS_MESSAGES_LIST, channel, json=optional( - channel=channel, + r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional( around=around, before=before, after=after, limit=limit )) - return [Message.create(self.client, i) for i in r.json()] + return Message.create_map(self.client, r.json()) def channels_messages_get(self, channel, message): - r = self.http(Routes.CHANNELS_MESSAGES_GET, channel, message) + r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) return Message.create(self.client, r.json()) def channels_messages_create(self, channel, content, nonce=None, tts=False): - r = self.http(Routes.CHANNELS_MESSAGES_CREATE, channel, json={ + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={ 'content': content, 'nonce': nonce, 'tts': tts, @@ -57,11 +59,117 @@ class APIClient(LoggingClass): return Message.create(self.client, r.json()) def channels_messages_modify(self, channel, message, content): - r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, channel, message, json={'content': content}) + r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, + dict(channel=channel, message=message), + json={'content': content}) return Message.create(self.client, r.json()) def channels_messages_delete(self, channel, message): - self.http(Routes.CHANNELS_MESSAGES_DELETE, channel, message) + self.http(Routes.CHANNELS_MESSAGES_DELETE, dict(channel=channel, message=message)) def channels_messages_delete_bulk(self, channel, messages): - self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, channel, json={'messages': messages}) + self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages}) + + def channels_permissions_modify(self, channel, permission, allow, deny, typ): + self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={ + 'allow': allow, + 'deny': deny, + 'type': typ, + }) + + def channels_permissions_delete(self, channel, permission): + self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission)) + + def channels_invites_list(self, channel): + r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel)) + return Invite.create_map(self.client, r.json()) + + def channels_invites_create(self, channel, max_age=86400, max_uses=0, temporary=False, unique=False): + r = self.http(Routes.CHANNELS_INVITES_CREATE, dict(channel=channel), json={ + 'max_age': max_age, + 'max_uses': max_uses, + 'temporary': temporary, + 'unique': unique + }) + return Invite.create(self.client, r.json()) + + def channels_pins_list(self, channel): + r = self.http(Routes.CHANNELS_PINS_LIST, dict(channel=channel)) + return Message.create_map(self.client, r.json()) + + def channels_pins_create(self, channel, message): + self.http(Routes.CHANNELS_PINS_CREATE, dict(channel=channel, message=message)) + + def channels_pins_delete(self, channel, message): + self.http(Routes.CHANNELS_PINS_DELETE, dict(channel=channel, message=message)) + + def guilds_get(self, guild): + r = self.http(Routes.GUILDS_GET, dict(guild=guild)) + return Guild.create(self.client, r.json()) + + def guilds_modify(self, guild, **kwargs): + r = self.http(Routes.GUILDS_MODIFY, dict(guild=guild), json=kwargs) + return Guild.create(self.client, r.json()) + + def guilds_delete(self, guild): + r = self.http(Routes.GUILDS_DELETE, dict(guild=guild)) + return Guild.create(self.client, r.json()) + + def guilds_channels_list(self, guild): + r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) + return Channel.create_map(self.client, r.json()) + + def guilds_channels_create(self, guild, **kwargs): + r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs) + return Channel.create(self.client, r.json()) + + def guilds_channels_modify(self, guild, channel, position): + self.http(Routes.GUILDS_CHANNELS_MODIFY, dict(guild=guild), json={ + 'id': channel, + 'position': position, + }) + + def guilds_members_list(self, guild): + r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild)) + return GuildMember.create_map(self.client, r.json()) + + def guilds_members_get(self, guild, member): + r = self.http(Routes.GUILD_MEMBERS_GET, dict(guild=guild, member=member)) + return GuildMember.create(self.client, r.json()) + + def guilds_members_modify(self, guild, member, **kwargs): + self.http(Routes.GUILD_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) + + def guilds_members_kick(self, guild, member): + self.http(Routes.GUILD_MEMBERS_KICK, dict(guild=guild, member=member)) + + def guilds_bans_list(self, guild): + r = self.http(Routes.GUILD_BANS_LIST, dict(guild=guild)) + return User.create_map(self.client, r.json()) + + def guilds_bans_create(self, guild, user, delete_message_days): + self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={ + 'delete-message-days': delete_message_days, + }) + + def guilds_bans_delete(self, guild, user): + self.http(Routes.GUILDS_BANS_DELETE, dict(guild=guild, user=user)) + + def guilds_roles_list(self, guild): + r = self.http(Routes.GUILDS_ROLES_LIST, dict(guild=guild)) + return Role.create_map(self.client, r.json()) + + def guilds_roles_create(self, guild): + r = self.http(Routes.GUILDS_ROLES_CREATE, dict(guild=guild)) + return Role.create(self.client, r.json()) + + def guilds_roles_modify_batch(self, guild, roles): + r = self.http(Routes.GUILDS_ROLES_MODIFY_BATCH, dict(guild=guild), json=roles) + return Role.create_map(self.client, r.json()) + + def guilds_roles_modify(self, guild, role, **kwargs): + r = self.http(Routes.GUILDS_ROLES_MODIFY, dict(guild=guild, role=role), json=kwargs) + return Role.create(self.client, r.json()) + + def guilds_roles_delete(self, guild, role): + self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role)) diff --git a/disco/api/http.py b/disco/api/http.py index bd5df9d..28cb892 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -1,4 +1,6 @@ import requests +import random +import gevent from holster.enum import Enum @@ -19,66 +21,69 @@ class Routes(object): GATEWAY_GET = (HTTPMethod.GET, '/gateway') # Channels - CHANNELS_GET = (HTTPMethod.GET, '/channels/{}') - CHANNELS_MODIFY= (HTTPMethod.PATCH, '/channels/{}') - CHANNELS_DELETE = (HTTPMethod.DELETE, '/channels/{}') - - CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, '/channels/{}/messages') - CHANNELS_MESSAGES_GET = (HTTPMethod.GET, '/channels/{}/messages/{}') - CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, '/channels/{}/messages') - CHANNELS_MESSAGES_MODFIY = (HTTPMethod.PATCH, '/channels/{}/messages/{}') - CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, '/channels/{}/messages/{}') - CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, '/channels/{}/messages/bulk_delete') - - CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, '/channels/{}/permissions/{}') - CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, '/channels/{}/permissions/{}') - CHANNELS_INVITES_LIST = (HTTPMethod.GET, '/channels/{}/invites') - CHANNELS_INVITES_CREATE = (HTTPMethod.POST, '/channels/{}/invites') - - CHANNELS_PINS_LIST = (HTTPMethod.GET, '/channels/{}/pins') - CHANNELS_PINS_CREATE = (HTTPMethod.PUT, '/channels/{}/pins/{}') - CHANNELS_PINS_DELETE = (HTTPMethod.DELETE, '/channels/{}/pins/{}') + CHANNELS = '/channels/{channel}' + CHANNELS_GET = (HTTPMethod.GET, CHANNELS) + CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS) + CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS) + + CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages') + CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}') + CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages') + CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}') + CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}') + CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete') + + CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}') + CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}') + CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites') + CHANNELS_INVITES_CREATE = (HTTPMethod.POST, CHANNELS + '/invites') + + CHANNELS_PINS_LIST = (HTTPMethod.GET, CHANNELS + '/pins') + CHANNELS_PINS_CREATE = (HTTPMethod.PUT, CHANNELS + '/pins/{pin}') + CHANNELS_PINS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/pins/{pin}') # Guilds - GUILDS_GET = (HTTPMethod.GET, '/guilds/{}') - GUILDS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}') - GUILDS_DELETE = (HTTPMethod.DELETE, '/guilds/{}') - GUILDS_CHANNELS_LIST = (HTTPMethod.GET, '/guilds/{}/channels') - GUILDS_CHANNELS_CREATE = (HTTPMethod.POST, '/guilds/{}/channels') - GUILDS_CHANNELS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/channels') - GUILDS_MEMBERS_LIST = (HTTPMethod.GET, '/guilds/{}/members') - GUILDS_MEMBERS_GET = (HTTPMethod.GET, '/guilds/{}/members/{}') - GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/members/{}') - GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, '/guilds/{}/members/{}') - GUILDS_BANS_LIST = (HTTPMethod.GET, '/guilds/{}/bans') - GUILDS_BANS_CREATE = (HTTPMethod.PUT, '/guilds/{}/bans/{}') - GUILDS_BANS_DELETE = (HTTPMethod.DELETE, '/guilds/{}/bans/{}') - GUILDS_ROLES_LIST = (HTTPMethod.GET, '/guilds/{}/roles') - GUILDS_ROLES_CREATE = (HTTPMethod.GET, '/guilds/{}/roles') - GUILDS_ROLES_MODIFY_BATCH = (HTTPMethod.PATCH, '/guilds/{}/roles') - GUILDS_ROLES_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/roles/{}') - GUILDS_ROLES_DELETE = (HTTPMethod.DELETE, '/guilds/{}/roles/{}') - GUILDS_PRUNE_COUNT = (HTTPMethod.GET, '/guilds/{}/prune') - GUILDS_PRUNE_BEGIN = (HTTPMethod.POST, '/guilds/{}/prune') - GUILDS_VOICE_REGIONS_LIST = (HTTPMethod.GET, '/guilds/{}/regions') - GUILDS_INVITES_LIST = (HTTPMethod.GET, '/guilds/{}/invites') - GUILDS_INTEGRATIONS_LIST = (HTTPMethod.GET, '/guilds/{}/integrations') - GUILDS_INTEGRATIONS_CREATE = (HTTPMethod.POST, '/guilds/{}/integrations') - GUILDS_INTEGRATIONS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/integrations/{}') - GUILDS_INTEGRATIONS_DELETE = (HTTPMethod.DELETE, '/guilds/{}/integrations/{}') - GUILDS_INTEGRATIONS_SYNC = (HTTPMethod.POST, '/guilds/{}/integrations/{}/sync') - GUILDS_EMBED_GET = (HTTPMethod.GET, '/guilds/{}/embed') - GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/embed') + GUILDS = '/guilds/{guild}' + GUILDS_GET = (HTTPMethod.GET, GUILDS) + GUILDS_MODIFY = (HTTPMethod.PATCH, GUILDS) + GUILDS_DELETE = (HTTPMethod.DELETE, GUILDS) + GUILDS_CHANNELS_LIST = (HTTPMethod.GET, GUILDS + '/channels') + GUILDS_CHANNELS_CREATE = (HTTPMethod.POST, GUILDS + '/channels') + GUILDS_CHANNELS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/channels') + GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members') + GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}') + GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}') + GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}') + GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans') + GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}') + GUILDS_BANS_DELETE = (HTTPMethod.DELETE, GUILDS + '/bans/{user}') + GUILDS_ROLES_LIST = (HTTPMethod.GET, GUILDS + '/roles') + GUILDS_ROLES_CREATE = (HTTPMethod.GET, GUILDS + '/roles') + GUILDS_ROLES_MODIFY_BATCH = (HTTPMethod.PATCH, GUILDS + '/roles') + GUILDS_ROLES_MODIFY = (HTTPMethod.PATCH, GUILDS + '/roles/{role}') + GUILDS_ROLES_DELETE = (HTTPMethod.DELETE, GUILDS + '/roles/{role}') + GUILDS_PRUNE_COUNT = (HTTPMethod.GET, GUILDS + '/prune') + GUILDS_PRUNE_BEGIN = (HTTPMethod.POST, GUILDS + '/prune') + GUILDS_VOICE_REGIONS_LIST = (HTTPMethod.GET, GUILDS + '/regions') + GUILDS_INVITES_LIST = (HTTPMethod.GET, GUILDS + '/invites') + GUILDS_INTEGRATIONS_LIST = (HTTPMethod.GET, GUILDS + '/integrations') + GUILDS_INTEGRATIONS_CREATE = (HTTPMethod.POST, GUILDS + '/integrations') + GUILDS_INTEGRATIONS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/integrations/{integration}') + GUILDS_INTEGRATIONS_DELETE = (HTTPMethod.DELETE, GUILDS + '/integrations/{integration}') + GUILDS_INTEGRATIONS_SYNC = (HTTPMethod.POST, GUILDS + '/integrations/{integration}/sync') + GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed') + GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed') # Users - USERS_ME_GET = (HTTPMethod.GET, '/users/@me') - USERS_ME_PATCH = (HTTPMethod.PATCH, '/users/@me') - USERS_ME_GUILDS_LIST = (HTTPMethod.GET, '/users/@me/guilds') - USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, '/users/@me/guilds/{}') - USERS_ME_DMS_LIST = (HTTPMethod.GET, '/users/@me/channels') - USERS_ME_DMS_CREATE = (HTTPMethod.POST, '/users/@me/channels') - USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, '/users/@me/connections') - USERS_GET = (HTTPMethod.GET, '/users/{}') + USERS = '/users' + USERS_ME_GET = (HTTPMethod.GET, USERS + '/@me') + USERS_ME_PATCH = (HTTPMethod.PATCH, USERS + '/@me') + USERS_ME_GUILDS_LIST = (HTTPMethod.GET, USERS + '/@me/guilds') + USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, USERS + '/@me/guilds/{guild}') + USERS_ME_DMS_LIST = (HTTPMethod.GET, USERS + '/@me/channels') + USERS_ME_DMS_CREATE = (HTTPMethod.POST, USERS + '/@me/channels') + USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, USERS + '/@me/connections') + USERS_GET = (HTTPMethod.GET, USERS + '/{user}') class APIException(Exception): @@ -89,7 +94,7 @@ class APIException(Exception): class HTTPClient(LoggingClass): - BASE_URL = 'https://discordapp.com/api' + BASE_URL = 'https://discordapp.com/api/v6' MAX_RETRIES = 5 def __init__(self, token): @@ -100,7 +105,8 @@ class HTTPClient(LoggingClass): 'Authorization': 'Bot ' + token, } - def __call__(self, route, *args, **kwargs): + def __call__(self, route, args=None, **kwargs): + args = args or {} retry = kwargs.pop('retry_number', 0) # Merge or set headers @@ -109,17 +115,20 @@ class HTTPClient(LoggingClass): else: kwargs['headers'] = self.headers - # Compile URL args - compiled = (str(route[0]), (self.BASE_URL) + route[1].format(*args)) + # Build the bucket URL + filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in args.items()} + bucket = (route[0].value, route[1].format(**filtered)) # Possibly wait if we're rate limited - self.limiter.check(compiled) + self.limiter.check(bucket) # Make the actual request - r = requests.request(compiled[0], compiled[1], **kwargs) + url = self.BASE_URL + route[1].format(**args) + print route[0].value, url, kwargs + r = requests.request(route[0].value, url, **kwargs) # Update rate limiter - self.limiter.update(compiled, r) + self.limiter.update(bucket, r) # If we got a success status code, just return the data if r.status_code < 400: @@ -134,5 +143,14 @@ class HTTPClient(LoggingClass): self.log.error('Failing request, hit max retries') raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) + backoff = self.random_backoff() + self.log.warning('Request to `{}` failed with code {}, retrying after {}s'.format(url, r.status_code, backoff)) + gevent.sleep(backoff) + # Otherwise just recurse and try again - return self(route, retry_number=retry, *args, **kwargs) + return self(route, args, retry_number=retry, **kwargs) + + @staticmethod + def random_backoff(): + # 500 milliseconds to 5 seconds) + return random.randint(500, 5000) / 1000.0 diff --git a/disco/bot/command.py b/disco/bot/command.py index 4648e05..87fb4f9 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -15,6 +15,18 @@ class CommandEvent(object): self.name = self.match.group(1) self.args = self.match.group(2).strip().split(' ') + @property + def channel(self): + return self.msg.channel + + @property + def guild(self): + return self.msg.guild + + @property + def actor(self): + return self.msg.author + class CommandError(Exception): pass diff --git a/disco/bot/parser.py b/disco/bot/parser.py index fb203cf..b90fdfe 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -77,7 +77,10 @@ class ArgumentSet(object): if not arg.required and index + arg.true_count <= len(rawargs): continue - raw = rawargs[index:index + arg.true_count] + if arg.count == 0: + raw = rawargs[index:] + else: + raw = rawargs[index:index + arg.true_count] if arg.types: for idx, r in enumerate(raw): @@ -88,7 +91,7 @@ class ArgumentSet(object): r, ', '.join(arg.types) )) - if arg.true_count == 1: + if arg.count == 1: raw = raw[0] if not arg.types or arg.types == ['str'] and isinstance(raw, list): diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index a711df1..bac5145 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -61,6 +61,8 @@ class Plugin(LoggingClass, PluginDeco): def __init__(self, bot, config): super(Plugin, self).__init__() self.bot = bot + self.client = bot.client + self.state = bot.client.state self.config = config self.listeners = [] diff --git a/disco/gateway/events.py b/disco/gateway/events.py index e15bc0f..51b3835 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -1,8 +1,7 @@ import inflection import skema -from disco.util import recursive_find_matching -from disco.types.base import BaseType +from disco.util import skema_find_recursive_by_type from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState @@ -15,8 +14,7 @@ class GatewayEvent(skema.Model): obj = cls.create(obj['d']) - # TODO: use skema info - for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)): + for item in skema_find_recursive_by_type(obj, skema.ModelType): item.client = client return obj @@ -68,13 +66,17 @@ class GuildDelete(GatewayEvent): class ChannelCreate(Sub('channel')): channel = skema.ModelType(Channel) + @property + def guild(self): + return self.channel.guild -class ChannelUpdate(Sub('channel')): - channel = skema.ModelType(Channel) +class ChannelUpdate(ChannelCreate): + pass -class ChannelDelete(Sub('channel')): - channel = skema.ModelType(Channel) + +class ChannelDelete(ChannelCreate): + pass class ChannelPinsUpdate(GatewayEvent): @@ -136,8 +138,12 @@ class GuildRoleDelete(GatewayEvent): class MessageCreate(Sub('message')): message = skema.ModelType(Message) + @property + def channel(self): + return self.message.channel + -class MessageUpdate(Sub('message')): +class MessageUpdate(MessageCreate): message = skema.ModelType(Message) diff --git a/disco/state.py b/disco/state.py index d25e164..750853f 100644 --- a/disco/state.py +++ b/disco/state.py @@ -1,16 +1,36 @@ +from collections import defaultdict, deque, namedtuple +from weakref import WeakValueDictionary + + +StackMessage = namedtuple('StackMessage', ['id', 'channel_id', 'author_id']) + + +class StateConfig(object): + # Whether to keep a buffer of messages + track_messages = True + + # The number maximum number of messages to store + track_messages_size = 100 class State(object): - def __init__(self, client): + def __init__(self, client, config=None): self.client = client + self.config = config or StateConfig() self.me = None - self.channels = {} + self.dms = {} self.guilds = {} + self.channels = WeakValueDictionary() self.client.events.on('Ready', self.on_ready) + self.messages_stack = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) + if self.config.track_messages: + self.client.events.on('MessageCreate', self.on_message_create) + self.client.events.on('MessageDelete', self.on_message_delete) + # Guilds self.client.events.on('GuildCreate', self.on_guild_create) self.client.events.on('GuildUpdate', self.on_guild_update) @@ -24,27 +44,63 @@ class State(object): def on_ready(self, event): self.me = event.user + def on_message_create(self, event): + self.messages_stack[event.message.channel_id].append( + StackMessage(event.message.id, event.message.channel_id, event.message.author.id)) + + def on_message_update(self, event): + message, cid = event.message, event.message.channel_id + if cid not in self.messages_stack: + return + + sm = next((i for i in self.messages_stack[cid] if i.id == message.id), None) + if not sm: + return + + sm.id = message.id + sm.channel_id = cid + sm.author_id = message.author.id + + def on_message_delete(self, event): + if event.channel_id not in self.messages_stack: + return + + sm = next((i for i in self.messages_stack[event.channel_id] if i.id == event.id), None) + if not sm: + return + + self.messages_stack[event.channel_id].remove(sm) + def on_guild_create(self, event): self.guilds[event.guild.id] = event.guild - - for channel in event.guild.channels: - self.channels[channel.id] = channel + self.channels.update(event.guild.channels) def on_guild_update(self, event): self.guilds[event.guild.id] = event.guild def on_guild_delete(self, event): if event.guild_id in self.guilds: + # Just delete the guild, channel references will fall del self.guilds[event.guild_id] - # CHANNELS? - def on_channel_create(self, event): - self.channels[event.channel.id] = event.channel + if event.channel.is_guild and event.channel.guild_id in self.guilds: + self.guilds[event.channel.guild_id].channels[event.channel.id] = event.channel + self.channels[event.channel.id] = event.channel + elif event.channel.is_dm: + self.dms[event.channel.id] = event.channel + self.channels[event.channel.id] = event.channel def on_channel_update(self, event): - self.channels[event.channel.id] = event.channel + if event.channel.is_guild and event.channel.guild_id in self.guilds: + self.guilds[event.channel.id] = event.channel + self.channels[event.channel.id] = event.channel + elif event.channel.is_dm: + self.dms[event.channel.id] = event.channel + self.channels[event.channel.id] = event.channel def on_channel_delete(self, event): - if event.channel.id in self.channels: - del self.channels[event.channel.id] + if event.channel.is_guild and event.channel.guild_id in self.guilds: + del self.guilds[event.channel.id] + elif event.channel.is_dm: + del self.pms[event.channel.id] diff --git a/disco/types/base.py b/disco/types/base.py index 78bad1e..c6fb366 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -1,6 +1,7 @@ import skema +import functools -from disco.util import recursive_find_matching +from disco.util import skema_find_recursive_by_type class BaseType(skema.Model): @@ -11,9 +12,12 @@ class BaseType(skema.Model): # Valdiate obj.validate() - # TODO: this can be smarter using skema metadata - for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)): + for item in skema_find_recursive_by_type(obj, skema.ModelType): item.client = client obj.client = client return obj + + @classmethod + def create_map(cls, client, data): + return map(functools.partial(cls.create, client), data) diff --git a/disco/types/channel.py b/disco/types/channel.py index b404f2b..2e05e3e 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -3,6 +3,7 @@ import skema from holster.enum import Enum from disco.util.cache import cached_property +from disco.util.types import ListToDictType from disco.types.base import BaseType from disco.types.user import User @@ -34,18 +35,100 @@ class Channel(BaseType): name = skema.StringType() topic = skema.StringType() - last_message_id = skema.SnowflakeType() + _last_message_id = skema.SnowflakeType(stored_name='last_message_id') position = skema.IntType() bitrate = skema.IntType(required=False) recipient = skema.ModelType(User, required=False) type = skema.IntType(choices=ChannelType.ALL_VALUES) - permission_overwrites = skema.ListType(skema.ModelType(PermissionOverwrite)) + overwrites = ListToDictType('id', skema.ModelType(PermissionOverwrite), stored_name='permission_overwrites') + + @property + def is_guild(self): + return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE) + + @property + def is_dm(self): + return self.type in (ChannelType.DM, ChannelType.GROUP_DM) + + @property + def last_message_id(self): + if self.id not in self.client.state.messages_stack: + return self._last_message_id + return self.client.state.messages_stack[self.id][-1].id + + @property + def messages(self): + return self.messages_iter() + + def messages_iter(self, **kwargs): + return MessageIterator(self.client, self.id, before=self.last_message_id, **kwargs) @cached_property def guild(self): return self.client.state.guilds.get(self.guild_id) + def get_invites(self): + return self.client.api.channels_invites_list(self.id) + + def get_pins(self): + return self.client.api.channels_pins_list(self.id) + def send_message(self, content, nonce=None, tts=False): return self.client.api.channels_messages_create(self.id, content, nonce, tts) + + +class MessageIterator(object): + Direction = Enum('UP', 'DOWN') + + def __init__(self, client, channel, direction=Direction.UP, bulk=False, before=None, after=None, chunk_size=100): + self.client = client + self.channel = channel + self.direction = direction + self.bulk = bulk + self.before = before + self.after = after + self.chunk_size = chunk_size + + self.last = None + self._buffer = [] + + if len(filter(bool, (before, after))) > 1: + raise Exception('Must specify at most one of before or after') + + if not any((before, after)) and self.direction == self.Direction.DOWN: + raise Exception('Must specify either before or after for downward seeking') + + def fill(self): + self._buffer = self.client.api.channels_messages_list( + self.channel, + before=self.before, + after=self.after, + limit=self.chunk_size) + + if not len(self._buffer): + raise StopIteration + + self.after = None + self.before = None + + if self.direction == self.Direction.UP: + self.before = self._buffer[-1].id + else: + self._buffer.reverse() + self.after == self._buffer[-1].id + + def __iter__(self): + return self + + def next(self): + if not len(self._buffer): + self.fill() + + if self.bulk: + res = self._buffer + self._buffer = [] + return res + else: + return self._buffer.pop() diff --git a/disco/types/guild.py b/disco/types/guild.py index fa984c4..09c6c55 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -1,8 +1,7 @@ import skema -from disco.util.cache import cached_property from disco.types.base import BaseType -from disco.util.types import PreHookType +from disco.util.types import PreHookType, ListToDictType from disco.types.user import User from disco.types.voice import VoiceState from disco.types.channel import Channel @@ -33,6 +32,10 @@ class GuildMember(BaseType): joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) roles = skema.ListType(skema.SnowflakeType()) + @property + def id(self): + return self.user.id + class Guild(BaseType): id = skema.SnowflakeType() @@ -53,20 +56,16 @@ class Guild(BaseType): features = skema.ListType(skema.StringType()) - members = skema.ListType(skema.ModelType(GuildMember)) - voice_states = skema.ListType(skema.ModelType(VoiceState)) - channels = skema.ListType(skema.ModelType(Channel)) - roles = skema.ListType(skema.ModelType(Role)) - emojis = skema.ListType(skema.ModelType(Emoji)) - - @cached_property - def members_dict(self): - return {i.user.id: i for i in self.members} + members = ListToDictType('id', skema.ModelType(GuildMember)) + channels = ListToDictType('id', skema.ModelType(Channel)) + roles = ListToDictType('id', skema.ModelType(Role)) + emojis = ListToDictType('id', skema.ModelType(Emoji)) + voice_states = ListToDictType('id', skema.ModelType(VoiceState)) def get_member(self, user): - return self.members_dict.get(user.id) + return self.members.get(user.id) def validate_channels(self, ctx): if self.channels: - for channel in self.channels: + for channel in self.channels.values(): channel.guild_id = self.id diff --git a/disco/types/invite.py b/disco/types/invite.py new file mode 100644 index 0000000..45f08f8 --- /dev/null +++ b/disco/types/invite.py @@ -0,0 +1,22 @@ +import skema + +from disco.util.types import PreHookType +from disco.types.base import BaseType +from disco.types.user import User +from disco.types.guild import Guild +from disco.types.channel import Channel + + +class Invite(BaseType): + code = skema.StringType() + + inviter = skema.ModelType(User) + guild = skema.ModelType(Guild) + channel = skema.ModelType(Channel) + + max_age = skema.IntType() + max_uses = skema.IntType() + uses = skema.IntType() + temporary = skema.BooleanType() + + created_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) diff --git a/disco/types/message.py b/disco/types/message.py index 21c2490..fdd4ece 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -2,7 +2,7 @@ import re import skema from disco.util.cache import cached_property -from disco.util.types import PreHookType +from disco.util.types import PreHookType, ListToDictType from disco.types.base import BaseType from disco.types.user import User from disco.types.guild import Role @@ -41,11 +41,11 @@ class Message(BaseType): pinned = skema.BooleanType(required=False) - mentions = skema.ListType(skema.ModelType(User)) + mentions = ListToDictType('id', skema.ModelType(User)) mention_roles = skema.ListType(skema.SnowflakeType()) embeds = skema.ListType(skema.ModelType(MessageEmbed)) - attachment = skema.ListType(skema.ModelType(MessageAttachment)) + attachments = ListToDictType('id', skema.ModelType(MessageAttachment)) @cached_property def guild(self): @@ -55,14 +55,6 @@ class Message(BaseType): def channel(self): return self.client.state.channels.get(self.channel_id) - @cached_property - def mention_users(self): - return [i.id for i in self.mentions] - - @cached_property - def mention_users_dict(self): - return {i.id: i for i in self.mentions} - def reply(self, *args, **kwargs): return self.channel.send_message(*args, **kwargs) @@ -74,11 +66,13 @@ class Message(BaseType): def is_mentioned(self, entity): if isinstance(entity, User): - return entity.id in self.mention_users + return entity.id in self.mentions elif isinstance(entity, Role): return entity.id in self.mention_roles + elif isinstance(entity, long): + return entity in self.mentions or entity in self.mention_roles else: - raise Exception('Unknown entity: {}'.format(entity)) + raise Exception('Unknown entity: {} ({})'.format(entity, type(entity))) @cached_property def without_mentions(self): @@ -95,6 +89,6 @@ class Message(BaseType): if id in self.mention_roles: return role_replace(id) else: - return user_replace(self.mention_users_dict.get(id)) + return user_replace(self.mentions.get(id)) return re.sub('<@!?([0-9]+)>', replace, self.content) diff --git a/disco/util/__init__.py b/disco/util/__init__.py index bf2ff3c..1b5f31e 100644 --- a/disco/util/__init__.py +++ b/disco/util/__init__.py @@ -1,18 +1,36 @@ +import skema -def recursive_find_matching(base, match_clause): +def _recurse(typ, field, value): result = [] - if hasattr(base, '__dict__'): - values = base.__dict__.values() - else: - values = list(base) + if isinstance(field, skema.ModelType): + result += skema_find_recursive_by_type(value, typ) - for v in values: - if match_clause(v): + if isinstance(field, (skema.ListType, skema.SetType, skema.DictType)): + if isinstance(field, skema.DictType): + value = value.values() + + for item in value: + if isinstance(field.field, typ): + result.append(item) + result += _recurse(typ, field.field, item) + + return result + + +def skema_find_recursive_by_type(base, typ): + result = [] + + for name, field in base._fields_by_stored_name.items(): + v = getattr(base, name, None) + + if not v: + continue + + if isinstance(field, typ): result.append(v) - if hasattr(v, '__dict__') or hasattr(v, '__iter__'): - result += recursive_find_matching(v, match_clause) + result += _recurse(typ, field, v) return result diff --git a/disco/util/types.py b/disco/util/types.py index d8b6d0d..46e4d39 100644 --- a/disco/util/types.py +++ b/disco/util/types.py @@ -1,4 +1,4 @@ -from skema import BaseType +from skema import BaseType, DictType class PreHookType(BaseType): @@ -16,3 +16,21 @@ class PreHookType(BaseType): def to_storage(self, *args, **kwargs): return self.field.to_storage(*args, **kwargs) + + +class ListToDictType(DictType): + def __init__(self, key, *args, **kwargs): + super(ListToDictType, self).__init__(*args, **kwargs) + self.key = key + + def to_python(self, value): + if not value: + return {} + + to_python = self.field.to_python + + obj = {} + for item in value: + item = to_python(item) + obj[getattr(item, self.key)] = item + return obj diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 99f301a..c70e7c9 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -19,6 +19,36 @@ class BasicPlugin(Plugin): for i in range(count): event.msg.reply(content) + @Plugin.command('invites') + def on_invites(self, event): + invites = event.channel.get_invites() + event.msg.reply('Channel has a total of {} invites'.format(len(invites))) + + @Plugin.command('pins') + def on_pins(self, event): + pins = event.channel.get_pins() + event.msg.reply('Channel has a total of {} pins'.format(len(pins))) + + @Plugin.command('channel stats') + def on_stats(self, event): + msg = event.msg.reply('Ok, one moment...') + invite_count = len(event.channel.get_invites()) + pin_count = len(event.channel.get_pins()) + msg_count = 0 + + print event.channel.messages_iter(bulk=True) + for msgs in event.channel.messages_iter(bulk=True): + msg_count += len(msgs) + + msg.edit('{} invites, {} pins, {} messages'.format(invite_count, pin_count, msg_count)) + + @Plugin.command('messages stack') + def on_messages_stack(self, event): + event.msg.reply('Channels: {}, messages here: ```\n{}\n```'.format( + len(self.state.messages), + '\n'.join([str(i.id) for i in self.state.messages[event.channel.id]]) + )) + if __name__ == '__main__': bot = Bot(disco_main()) bot.add_plugin(BasicPlugin)