Browse Source

Lots of API stuff, state additions, skema/modeling cleanup

pull/3/head
Andrei 9 years ago
parent
commit
37d7b3bdef
  1. 132
      disco/api/client.py
  2. 146
      disco/api/http.py
  3. 12
      disco/bot/command.py
  4. 5
      disco/bot/parser.py
  5. 2
      disco/bot/plugin.py
  6. 24
      disco/gateway/events.py
  7. 74
      disco/state.py
  8. 10
      disco/types/base.py
  9. 87
      disco/types/channel.py
  10. 25
      disco/types/guild.py
  11. 22
      disco/types/invite.py
  12. 22
      disco/types/message.py
  13. 36
      disco/util/__init__.py
  14. 20
      disco/util/types.py
  15. 30
      examples/basic_plugin.py

132
disco/api/client.py

@ -1,12 +1,15 @@
from disco.api.http import Routes, HTTPClient from disco.api.http import Routes, HTTPClient
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
from disco.types.user import User
from disco.types.message import Message from disco.types.message import Message
from disco.types.guild import Guild, GuildMember, Role
from disco.types.channel import Channel from disco.types.channel import Channel
from disco.types.invite import Invite
def optional(**kwargs): def optional(**kwargs):
return {k: v for k, v in kwargs if v is not None} return {k: v for k, v in kwargs.items() if v is not None}
class APIClient(LoggingClass): class APIClient(LoggingClass):
@ -21,34 +24,33 @@ class APIClient(LoggingClass):
return data['url'] + '?v={}&encoding={}'.format(version, encoding) return data['url'] + '?v={}&encoding={}'.format(version, encoding)
def channels_get(self, channel): def channels_get(self, channel):
r = self.http(Routes.CHANNELS_GET, channel) r = self.http(Routes.CHANNELS_GET, dict(channel=channel))
return Channel.create(self.client, r.json()) return Channel.create(self.client, r.json())
def channels_modify(self, channel, **kwargs): def channels_modify(self, channel, **kwargs):
r = self.http(Routes.CHANNELS_MODIFY, channel, json=kwargs) r = self.http(Routes.CHANNELS_MODIFY, dict(channel=channel), json=kwargs)
return Channel.create(self.client, r.json()) return Channel.create(self.client, r.json())
def channels_delete(self, channel): def channels_delete(self, channel):
r = self.http(Routes.CHANNELS_DELETE, channel) r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel))
return Channel.create(self.client, r.json()) return Channel.create(self.client, r.json())
def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50): def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50):
r = self.http(Routes.CHANNELS_MESSAGES_LIST, channel, json=optional( r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional(
channel=channel,
around=around, around=around,
before=before, before=before,
after=after, after=after,
limit=limit limit=limit
)) ))
return [Message.create(self.client, i) for i in r.json()] return Message.create_map(self.client, r.json())
def channels_messages_get(self, channel, message): def channels_messages_get(self, channel, message):
r = self.http(Routes.CHANNELS_MESSAGES_GET, channel, message) r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message))
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())
def channels_messages_create(self, channel, content, nonce=None, tts=False): def channels_messages_create(self, channel, content, nonce=None, tts=False):
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, channel, json={ r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={
'content': content, 'content': content,
'nonce': nonce, 'nonce': nonce,
'tts': tts, 'tts': tts,
@ -57,11 +59,117 @@ class APIClient(LoggingClass):
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())
def channels_messages_modify(self, channel, message, content): def channels_messages_modify(self, channel, message, content):
r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, channel, message, json={'content': content}) r = self.http(Routes.CHANNELS_MESSAGES_MODIFY,
dict(channel=channel, message=message),
json={'content': content})
return Message.create(self.client, r.json()) return Message.create(self.client, r.json())
def channels_messages_delete(self, channel, message): def channels_messages_delete(self, channel, message):
self.http(Routes.CHANNELS_MESSAGES_DELETE, channel, message) self.http(Routes.CHANNELS_MESSAGES_DELETE, dict(channel=channel, message=message))
def channels_messages_delete_bulk(self, channel, messages): def channels_messages_delete_bulk(self, channel, messages):
self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, channel, json={'messages': messages}) self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages})
def channels_permissions_modify(self, channel, permission, allow, deny, typ):
self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={
'allow': allow,
'deny': deny,
'type': typ,
})
def channels_permissions_delete(self, channel, permission):
self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission))
def channels_invites_list(self, channel):
r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel))
return Invite.create_map(self.client, r.json())
def channels_invites_create(self, channel, max_age=86400, max_uses=0, temporary=False, unique=False):
r = self.http(Routes.CHANNELS_INVITES_CREATE, dict(channel=channel), json={
'max_age': max_age,
'max_uses': max_uses,
'temporary': temporary,
'unique': unique
})
return Invite.create(self.client, r.json())
def channels_pins_list(self, channel):
r = self.http(Routes.CHANNELS_PINS_LIST, dict(channel=channel))
return Message.create_map(self.client, r.json())
def channels_pins_create(self, channel, message):
self.http(Routes.CHANNELS_PINS_CREATE, dict(channel=channel, message=message))
def channels_pins_delete(self, channel, message):
self.http(Routes.CHANNELS_PINS_DELETE, dict(channel=channel, message=message))
def guilds_get(self, guild):
r = self.http(Routes.GUILDS_GET, dict(guild=guild))
return Guild.create(self.client, r.json())
def guilds_modify(self, guild, **kwargs):
r = self.http(Routes.GUILDS_MODIFY, dict(guild=guild), json=kwargs)
return Guild.create(self.client, r.json())
def guilds_delete(self, guild):
r = self.http(Routes.GUILDS_DELETE, dict(guild=guild))
return Guild.create(self.client, r.json())
def guilds_channels_list(self, guild):
r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
return Channel.create_map(self.client, r.json())
def guilds_channels_create(self, guild, **kwargs):
r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs)
return Channel.create(self.client, r.json())
def guilds_channels_modify(self, guild, channel, position):
self.http(Routes.GUILDS_CHANNELS_MODIFY, dict(guild=guild), json={
'id': channel,
'position': position,
})
def guilds_members_list(self, guild):
r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild))
return GuildMember.create_map(self.client, r.json())
def guilds_members_get(self, guild, member):
r = self.http(Routes.GUILD_MEMBERS_GET, dict(guild=guild, member=member))
return GuildMember.create(self.client, r.json())
def guilds_members_modify(self, guild, member, **kwargs):
self.http(Routes.GUILD_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs)
def guilds_members_kick(self, guild, member):
self.http(Routes.GUILD_MEMBERS_KICK, dict(guild=guild, member=member))
def guilds_bans_list(self, guild):
r = self.http(Routes.GUILD_BANS_LIST, dict(guild=guild))
return User.create_map(self.client, r.json())
def guilds_bans_create(self, guild, user, delete_message_days):
self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={
'delete-message-days': delete_message_days,
})
def guilds_bans_delete(self, guild, user):
self.http(Routes.GUILDS_BANS_DELETE, dict(guild=guild, user=user))
def guilds_roles_list(self, guild):
r = self.http(Routes.GUILDS_ROLES_LIST, dict(guild=guild))
return Role.create_map(self.client, r.json())
def guilds_roles_create(self, guild):
r = self.http(Routes.GUILDS_ROLES_CREATE, dict(guild=guild))
return Role.create(self.client, r.json())
def guilds_roles_modify_batch(self, guild, roles):
r = self.http(Routes.GUILDS_ROLES_MODIFY_BATCH, dict(guild=guild), json=roles)
return Role.create_map(self.client, r.json())
def guilds_roles_modify(self, guild, role, **kwargs):
r = self.http(Routes.GUILDS_ROLES_MODIFY, dict(guild=guild, role=role), json=kwargs)
return Role.create(self.client, r.json())
def guilds_roles_delete(self, guild, role):
self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role))

146
disco/api/http.py

@ -1,4 +1,6 @@
import requests import requests
import random
import gevent
from holster.enum import Enum from holster.enum import Enum
@ -19,66 +21,69 @@ class Routes(object):
GATEWAY_GET = (HTTPMethod.GET, '/gateway') GATEWAY_GET = (HTTPMethod.GET, '/gateway')
# Channels # Channels
CHANNELS_GET = (HTTPMethod.GET, '/channels/{}') CHANNELS = '/channels/{channel}'
CHANNELS_MODIFY= (HTTPMethod.PATCH, '/channels/{}') CHANNELS_GET = (HTTPMethod.GET, CHANNELS)
CHANNELS_DELETE = (HTTPMethod.DELETE, '/channels/{}') CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS)
CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS)
CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, '/channels/{}/messages')
CHANNELS_MESSAGES_GET = (HTTPMethod.GET, '/channels/{}/messages/{}') CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages')
CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, '/channels/{}/messages') CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_MODFIY = (HTTPMethod.PATCH, '/channels/{}/messages/{}') CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages')
CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, '/channels/{}/messages/{}') CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, '/channels/{}/messages/bulk_delete') CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete')
CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, '/channels/{}/permissions/{}')
CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, '/channels/{}/permissions/{}') CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}')
CHANNELS_INVITES_LIST = (HTTPMethod.GET, '/channels/{}/invites') CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}')
CHANNELS_INVITES_CREATE = (HTTPMethod.POST, '/channels/{}/invites') CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites')
CHANNELS_INVITES_CREATE = (HTTPMethod.POST, CHANNELS + '/invites')
CHANNELS_PINS_LIST = (HTTPMethod.GET, '/channels/{}/pins')
CHANNELS_PINS_CREATE = (HTTPMethod.PUT, '/channels/{}/pins/{}') CHANNELS_PINS_LIST = (HTTPMethod.GET, CHANNELS + '/pins')
CHANNELS_PINS_DELETE = (HTTPMethod.DELETE, '/channels/{}/pins/{}') CHANNELS_PINS_CREATE = (HTTPMethod.PUT, CHANNELS + '/pins/{pin}')
CHANNELS_PINS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/pins/{pin}')
# Guilds # Guilds
GUILDS_GET = (HTTPMethod.GET, '/guilds/{}') GUILDS = '/guilds/{guild}'
GUILDS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}') GUILDS_GET = (HTTPMethod.GET, GUILDS)
GUILDS_DELETE = (HTTPMethod.DELETE, '/guilds/{}') GUILDS_MODIFY = (HTTPMethod.PATCH, GUILDS)
GUILDS_CHANNELS_LIST = (HTTPMethod.GET, '/guilds/{}/channels') GUILDS_DELETE = (HTTPMethod.DELETE, GUILDS)
GUILDS_CHANNELS_CREATE = (HTTPMethod.POST, '/guilds/{}/channels') GUILDS_CHANNELS_LIST = (HTTPMethod.GET, GUILDS + '/channels')
GUILDS_CHANNELS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/channels') GUILDS_CHANNELS_CREATE = (HTTPMethod.POST, GUILDS + '/channels')
GUILDS_MEMBERS_LIST = (HTTPMethod.GET, '/guilds/{}/members') GUILDS_CHANNELS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/channels')
GUILDS_MEMBERS_GET = (HTTPMethod.GET, '/guilds/{}/members/{}') GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members')
GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/members/{}') GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}')
GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, '/guilds/{}/members/{}') GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}')
GUILDS_BANS_LIST = (HTTPMethod.GET, '/guilds/{}/bans') GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}')
GUILDS_BANS_CREATE = (HTTPMethod.PUT, '/guilds/{}/bans/{}') GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans')
GUILDS_BANS_DELETE = (HTTPMethod.DELETE, '/guilds/{}/bans/{}') GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}')
GUILDS_ROLES_LIST = (HTTPMethod.GET, '/guilds/{}/roles') GUILDS_BANS_DELETE = (HTTPMethod.DELETE, GUILDS + '/bans/{user}')
GUILDS_ROLES_CREATE = (HTTPMethod.GET, '/guilds/{}/roles') GUILDS_ROLES_LIST = (HTTPMethod.GET, GUILDS + '/roles')
GUILDS_ROLES_MODIFY_BATCH = (HTTPMethod.PATCH, '/guilds/{}/roles') GUILDS_ROLES_CREATE = (HTTPMethod.GET, GUILDS + '/roles')
GUILDS_ROLES_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/roles/{}') GUILDS_ROLES_MODIFY_BATCH = (HTTPMethod.PATCH, GUILDS + '/roles')
GUILDS_ROLES_DELETE = (HTTPMethod.DELETE, '/guilds/{}/roles/{}') GUILDS_ROLES_MODIFY = (HTTPMethod.PATCH, GUILDS + '/roles/{role}')
GUILDS_PRUNE_COUNT = (HTTPMethod.GET, '/guilds/{}/prune') GUILDS_ROLES_DELETE = (HTTPMethod.DELETE, GUILDS + '/roles/{role}')
GUILDS_PRUNE_BEGIN = (HTTPMethod.POST, '/guilds/{}/prune') GUILDS_PRUNE_COUNT = (HTTPMethod.GET, GUILDS + '/prune')
GUILDS_VOICE_REGIONS_LIST = (HTTPMethod.GET, '/guilds/{}/regions') GUILDS_PRUNE_BEGIN = (HTTPMethod.POST, GUILDS + '/prune')
GUILDS_INVITES_LIST = (HTTPMethod.GET, '/guilds/{}/invites') GUILDS_VOICE_REGIONS_LIST = (HTTPMethod.GET, GUILDS + '/regions')
GUILDS_INTEGRATIONS_LIST = (HTTPMethod.GET, '/guilds/{}/integrations') GUILDS_INVITES_LIST = (HTTPMethod.GET, GUILDS + '/invites')
GUILDS_INTEGRATIONS_CREATE = (HTTPMethod.POST, '/guilds/{}/integrations') GUILDS_INTEGRATIONS_LIST = (HTTPMethod.GET, GUILDS + '/integrations')
GUILDS_INTEGRATIONS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/integrations/{}') GUILDS_INTEGRATIONS_CREATE = (HTTPMethod.POST, GUILDS + '/integrations')
GUILDS_INTEGRATIONS_DELETE = (HTTPMethod.DELETE, '/guilds/{}/integrations/{}') GUILDS_INTEGRATIONS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/integrations/{integration}')
GUILDS_INTEGRATIONS_SYNC = (HTTPMethod.POST, '/guilds/{}/integrations/{}/sync') GUILDS_INTEGRATIONS_DELETE = (HTTPMethod.DELETE, GUILDS + '/integrations/{integration}')
GUILDS_EMBED_GET = (HTTPMethod.GET, '/guilds/{}/embed') GUILDS_INTEGRATIONS_SYNC = (HTTPMethod.POST, GUILDS + '/integrations/{integration}/sync')
GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/embed') GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed')
GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed')
# Users # Users
USERS_ME_GET = (HTTPMethod.GET, '/users/@me') USERS = '/users'
USERS_ME_PATCH = (HTTPMethod.PATCH, '/users/@me') USERS_ME_GET = (HTTPMethod.GET, USERS + '/@me')
USERS_ME_GUILDS_LIST = (HTTPMethod.GET, '/users/@me/guilds') USERS_ME_PATCH = (HTTPMethod.PATCH, USERS + '/@me')
USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, '/users/@me/guilds/{}') USERS_ME_GUILDS_LIST = (HTTPMethod.GET, USERS + '/@me/guilds')
USERS_ME_DMS_LIST = (HTTPMethod.GET, '/users/@me/channels') USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, USERS + '/@me/guilds/{guild}')
USERS_ME_DMS_CREATE = (HTTPMethod.POST, '/users/@me/channels') USERS_ME_DMS_LIST = (HTTPMethod.GET, USERS + '/@me/channels')
USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, '/users/@me/connections') USERS_ME_DMS_CREATE = (HTTPMethod.POST, USERS + '/@me/channels')
USERS_GET = (HTTPMethod.GET, '/users/{}') USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, USERS + '/@me/connections')
USERS_GET = (HTTPMethod.GET, USERS + '/{user}')
class APIException(Exception): class APIException(Exception):
@ -89,7 +94,7 @@ class APIException(Exception):
class HTTPClient(LoggingClass): class HTTPClient(LoggingClass):
BASE_URL = 'https://discordapp.com/api' BASE_URL = 'https://discordapp.com/api/v6'
MAX_RETRIES = 5 MAX_RETRIES = 5
def __init__(self, token): def __init__(self, token):
@ -100,7 +105,8 @@ class HTTPClient(LoggingClass):
'Authorization': 'Bot ' + token, 'Authorization': 'Bot ' + token,
} }
def __call__(self, route, *args, **kwargs): def __call__(self, route, args=None, **kwargs):
args = args or {}
retry = kwargs.pop('retry_number', 0) retry = kwargs.pop('retry_number', 0)
# Merge or set headers # Merge or set headers
@ -109,17 +115,20 @@ class HTTPClient(LoggingClass):
else: else:
kwargs['headers'] = self.headers kwargs['headers'] = self.headers
# Compile URL args # Build the bucket URL
compiled = (str(route[0]), (self.BASE_URL) + route[1].format(*args)) filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in args.items()}
bucket = (route[0].value, route[1].format(**filtered))
# Possibly wait if we're rate limited # Possibly wait if we're rate limited
self.limiter.check(compiled) self.limiter.check(bucket)
# Make the actual request # Make the actual request
r = requests.request(compiled[0], compiled[1], **kwargs) url = self.BASE_URL + route[1].format(**args)
print route[0].value, url, kwargs
r = requests.request(route[0].value, url, **kwargs)
# Update rate limiter # Update rate limiter
self.limiter.update(compiled, r) self.limiter.update(bucket, r)
# If we got a success status code, just return the data # If we got a success status code, just return the data
if r.status_code < 400: if r.status_code < 400:
@ -134,5 +143,14 @@ class HTTPClient(LoggingClass):
self.log.error('Failing request, hit max retries') self.log.error('Failing request, hit max retries')
raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content)
backoff = self.random_backoff()
self.log.warning('Request to `{}` failed with code {}, retrying after {}s'.format(url, r.status_code, backoff))
gevent.sleep(backoff)
# Otherwise just recurse and try again # Otherwise just recurse and try again
return self(route, retry_number=retry, *args, **kwargs) return self(route, args, retry_number=retry, **kwargs)
@staticmethod
def random_backoff():
# 500 milliseconds to 5 seconds)
return random.randint(500, 5000) / 1000.0

12
disco/bot/command.py

@ -15,6 +15,18 @@ class CommandEvent(object):
self.name = self.match.group(1) self.name = self.match.group(1)
self.args = self.match.group(2).strip().split(' ') self.args = self.match.group(2).strip().split(' ')
@property
def channel(self):
return self.msg.channel
@property
def guild(self):
return self.msg.guild
@property
def actor(self):
return self.msg.author
class CommandError(Exception): class CommandError(Exception):
pass pass

5
disco/bot/parser.py

@ -77,6 +77,9 @@ class ArgumentSet(object):
if not arg.required and index + arg.true_count <= len(rawargs): if not arg.required and index + arg.true_count <= len(rawargs):
continue continue
if arg.count == 0:
raw = rawargs[index:]
else:
raw = rawargs[index:index + arg.true_count] raw = rawargs[index:index + arg.true_count]
if arg.types: if arg.types:
@ -88,7 +91,7 @@ class ArgumentSet(object):
r, ', '.join(arg.types) r, ', '.join(arg.types)
)) ))
if arg.true_count == 1: if arg.count == 1:
raw = raw[0] raw = raw[0]
if not arg.types or arg.types == ['str'] and isinstance(raw, list): if not arg.types or arg.types == ['str'] and isinstance(raw, list):

2
disco/bot/plugin.py

@ -61,6 +61,8 @@ class Plugin(LoggingClass, PluginDeco):
def __init__(self, bot, config): def __init__(self, bot, config):
super(Plugin, self).__init__() super(Plugin, self).__init__()
self.bot = bot self.bot = bot
self.client = bot.client
self.state = bot.client.state
self.config = config self.config = config
self.listeners = [] self.listeners = []

24
disco/gateway/events.py

@ -1,8 +1,7 @@
import inflection import inflection
import skema import skema
from disco.util import recursive_find_matching from disco.util import skema_find_recursive_by_type
from disco.types.base import BaseType
from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState
@ -15,8 +14,7 @@ class GatewayEvent(skema.Model):
obj = cls.create(obj['d']) obj = cls.create(obj['d'])
# TODO: use skema info for item in skema_find_recursive_by_type(obj, skema.ModelType):
for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)):
item.client = client item.client = client
return obj return obj
@ -68,13 +66,17 @@ class GuildDelete(GatewayEvent):
class ChannelCreate(Sub('channel')): class ChannelCreate(Sub('channel')):
channel = skema.ModelType(Channel) channel = skema.ModelType(Channel)
@property
def guild(self):
return self.channel.guild
class ChannelUpdate(Sub('channel')):
channel = skema.ModelType(Channel)
class ChannelUpdate(ChannelCreate):
pass
class ChannelDelete(Sub('channel')):
channel = skema.ModelType(Channel) class ChannelDelete(ChannelCreate):
pass
class ChannelPinsUpdate(GatewayEvent): class ChannelPinsUpdate(GatewayEvent):
@ -136,8 +138,12 @@ class GuildRoleDelete(GatewayEvent):
class MessageCreate(Sub('message')): class MessageCreate(Sub('message')):
message = skema.ModelType(Message) message = skema.ModelType(Message)
@property
def channel(self):
return self.message.channel
class MessageUpdate(Sub('message')): class MessageUpdate(MessageCreate):
message = skema.ModelType(Message) message = skema.ModelType(Message)

74
disco/state.py

@ -1,16 +1,36 @@
from collections import defaultdict, deque, namedtuple
from weakref import WeakValueDictionary
StackMessage = namedtuple('StackMessage', ['id', 'channel_id', 'author_id'])
class StateConfig(object):
# Whether to keep a buffer of messages
track_messages = True
# The number maximum number of messages to store
track_messages_size = 100
class State(object): class State(object):
def __init__(self, client): def __init__(self, client, config=None):
self.client = client self.client = client
self.config = config or StateConfig()
self.me = None self.me = None
self.channels = {} self.dms = {}
self.guilds = {} self.guilds = {}
self.channels = WeakValueDictionary()
self.client.events.on('Ready', self.on_ready) self.client.events.on('Ready', self.on_ready)
self.messages_stack = defaultdict(lambda: deque(maxlen=self.config.track_messages_size))
if self.config.track_messages:
self.client.events.on('MessageCreate', self.on_message_create)
self.client.events.on('MessageDelete', self.on_message_delete)
# Guilds # Guilds
self.client.events.on('GuildCreate', self.on_guild_create) self.client.events.on('GuildCreate', self.on_guild_create)
self.client.events.on('GuildUpdate', self.on_guild_update) self.client.events.on('GuildUpdate', self.on_guild_update)
@ -24,27 +44,63 @@ class State(object):
def on_ready(self, event): def on_ready(self, event):
self.me = event.user self.me = event.user
def on_message_create(self, event):
self.messages_stack[event.message.channel_id].append(
StackMessage(event.message.id, event.message.channel_id, event.message.author.id))
def on_message_update(self, event):
message, cid = event.message, event.message.channel_id
if cid not in self.messages_stack:
return
sm = next((i for i in self.messages_stack[cid] if i.id == message.id), None)
if not sm:
return
sm.id = message.id
sm.channel_id = cid
sm.author_id = message.author.id
def on_message_delete(self, event):
if event.channel_id not in self.messages_stack:
return
sm = next((i for i in self.messages_stack[event.channel_id] if i.id == event.id), None)
if not sm:
return
self.messages_stack[event.channel_id].remove(sm)
def on_guild_create(self, event): def on_guild_create(self, event):
self.guilds[event.guild.id] = event.guild self.guilds[event.guild.id] = event.guild
self.channels.update(event.guild.channels)
for channel in event.guild.channels:
self.channels[channel.id] = channel
def on_guild_update(self, event): def on_guild_update(self, event):
self.guilds[event.guild.id] = event.guild self.guilds[event.guild.id] = event.guild
def on_guild_delete(self, event): def on_guild_delete(self, event):
if event.guild_id in self.guilds: if event.guild_id in self.guilds:
# Just delete the guild, channel references will fall
del self.guilds[event.guild_id] del self.guilds[event.guild_id]
# CHANNELS?
def on_channel_create(self, event): def on_channel_create(self, event):
if event.channel.is_guild and event.channel.guild_id in self.guilds:
self.guilds[event.channel.guild_id].channels[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel
elif event.channel.is_dm:
self.dms[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel self.channels[event.channel.id] = event.channel
def on_channel_update(self, event): def on_channel_update(self, event):
if event.channel.is_guild and event.channel.guild_id in self.guilds:
self.guilds[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel
elif event.channel.is_dm:
self.dms[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel self.channels[event.channel.id] = event.channel
def on_channel_delete(self, event): def on_channel_delete(self, event):
if event.channel.id in self.channels: if event.channel.is_guild and event.channel.guild_id in self.guilds:
del self.channels[event.channel.id] del self.guilds[event.channel.id]
elif event.channel.is_dm:
del self.pms[event.channel.id]

10
disco/types/base.py

@ -1,6 +1,7 @@
import skema import skema
import functools
from disco.util import recursive_find_matching from disco.util import skema_find_recursive_by_type
class BaseType(skema.Model): class BaseType(skema.Model):
@ -11,9 +12,12 @@ class BaseType(skema.Model):
# Valdiate # Valdiate
obj.validate() obj.validate()
# TODO: this can be smarter using skema metadata for item in skema_find_recursive_by_type(obj, skema.ModelType):
for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)):
item.client = client item.client = client
obj.client = client obj.client = client
return obj return obj
@classmethod
def create_map(cls, client, data):
return map(functools.partial(cls.create, client), data)

87
disco/types/channel.py

@ -3,6 +3,7 @@ import skema
from holster.enum import Enum from holster.enum import Enum
from disco.util.cache import cached_property from disco.util.cache import cached_property
from disco.util.types import ListToDictType
from disco.types.base import BaseType from disco.types.base import BaseType
from disco.types.user import User from disco.types.user import User
@ -34,18 +35,100 @@ class Channel(BaseType):
name = skema.StringType() name = skema.StringType()
topic = skema.StringType() topic = skema.StringType()
last_message_id = skema.SnowflakeType() _last_message_id = skema.SnowflakeType(stored_name='last_message_id')
position = skema.IntType() position = skema.IntType()
bitrate = skema.IntType(required=False) bitrate = skema.IntType(required=False)
recipient = skema.ModelType(User, required=False) recipient = skema.ModelType(User, required=False)
type = skema.IntType(choices=ChannelType.ALL_VALUES) type = skema.IntType(choices=ChannelType.ALL_VALUES)
permission_overwrites = skema.ListType(skema.ModelType(PermissionOverwrite)) overwrites = ListToDictType('id', skema.ModelType(PermissionOverwrite), stored_name='permission_overwrites')
@property
def is_guild(self):
return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE)
@property
def is_dm(self):
return self.type in (ChannelType.DM, ChannelType.GROUP_DM)
@property
def last_message_id(self):
if self.id not in self.client.state.messages_stack:
return self._last_message_id
return self.client.state.messages_stack[self.id][-1].id
@property
def messages(self):
return self.messages_iter()
def messages_iter(self, **kwargs):
return MessageIterator(self.client, self.id, before=self.last_message_id, **kwargs)
@cached_property @cached_property
def guild(self): def guild(self):
return self.client.state.guilds.get(self.guild_id) return self.client.state.guilds.get(self.guild_id)
def get_invites(self):
return self.client.api.channels_invites_list(self.id)
def get_pins(self):
return self.client.api.channels_pins_list(self.id)
def send_message(self, content, nonce=None, tts=False): def send_message(self, content, nonce=None, tts=False):
return self.client.api.channels_messages_create(self.id, content, nonce, tts) return self.client.api.channels_messages_create(self.id, content, nonce, tts)
class MessageIterator(object):
Direction = Enum('UP', 'DOWN')
def __init__(self, client, channel, direction=Direction.UP, bulk=False, before=None, after=None, chunk_size=100):
self.client = client
self.channel = channel
self.direction = direction
self.bulk = bulk
self.before = before
self.after = after
self.chunk_size = chunk_size
self.last = None
self._buffer = []
if len(filter(bool, (before, after))) > 1:
raise Exception('Must specify at most one of before or after')
if not any((before, after)) and self.direction == self.Direction.DOWN:
raise Exception('Must specify either before or after for downward seeking')
def fill(self):
self._buffer = self.client.api.channels_messages_list(
self.channel,
before=self.before,
after=self.after,
limit=self.chunk_size)
if not len(self._buffer):
raise StopIteration
self.after = None
self.before = None
if self.direction == self.Direction.UP:
self.before = self._buffer[-1].id
else:
self._buffer.reverse()
self.after == self._buffer[-1].id
def __iter__(self):
return self
def next(self):
if not len(self._buffer):
self.fill()
if self.bulk:
res = self._buffer
self._buffer = []
return res
else:
return self._buffer.pop()

25
disco/types/guild.py

@ -1,8 +1,7 @@
import skema import skema
from disco.util.cache import cached_property
from disco.types.base import BaseType from disco.types.base import BaseType
from disco.util.types import PreHookType from disco.util.types import PreHookType, ListToDictType
from disco.types.user import User from disco.types.user import User
from disco.types.voice import VoiceState from disco.types.voice import VoiceState
from disco.types.channel import Channel from disco.types.channel import Channel
@ -33,6 +32,10 @@ class GuildMember(BaseType):
joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType())
roles = skema.ListType(skema.SnowflakeType()) roles = skema.ListType(skema.SnowflakeType())
@property
def id(self):
return self.user.id
class Guild(BaseType): class Guild(BaseType):
id = skema.SnowflakeType() id = skema.SnowflakeType()
@ -53,20 +56,16 @@ class Guild(BaseType):
features = skema.ListType(skema.StringType()) features = skema.ListType(skema.StringType())
members = skema.ListType(skema.ModelType(GuildMember)) members = ListToDictType('id', skema.ModelType(GuildMember))
voice_states = skema.ListType(skema.ModelType(VoiceState)) channels = ListToDictType('id', skema.ModelType(Channel))
channels = skema.ListType(skema.ModelType(Channel)) roles = ListToDictType('id', skema.ModelType(Role))
roles = skema.ListType(skema.ModelType(Role)) emojis = ListToDictType('id', skema.ModelType(Emoji))
emojis = skema.ListType(skema.ModelType(Emoji)) voice_states = ListToDictType('id', skema.ModelType(VoiceState))
@cached_property
def members_dict(self):
return {i.user.id: i for i in self.members}
def get_member(self, user): def get_member(self, user):
return self.members_dict.get(user.id) return self.members.get(user.id)
def validate_channels(self, ctx): def validate_channels(self, ctx):
if self.channels: if self.channels:
for channel in self.channels: for channel in self.channels.values():
channel.guild_id = self.id channel.guild_id = self.id

22
disco/types/invite.py

@ -0,0 +1,22 @@
import skema
from disco.util.types import PreHookType
from disco.types.base import BaseType
from disco.types.user import User
from disco.types.guild import Guild
from disco.types.channel import Channel
class Invite(BaseType):
code = skema.StringType()
inviter = skema.ModelType(User)
guild = skema.ModelType(Guild)
channel = skema.ModelType(Channel)
max_age = skema.IntType()
max_uses = skema.IntType()
uses = skema.IntType()
temporary = skema.BooleanType()
created_at = PreHookType(lambda k: k[:-6], skema.DateTimeType())

22
disco/types/message.py

@ -2,7 +2,7 @@ import re
import skema import skema
from disco.util.cache import cached_property from disco.util.cache import cached_property
from disco.util.types import PreHookType from disco.util.types import PreHookType, ListToDictType
from disco.types.base import BaseType from disco.types.base import BaseType
from disco.types.user import User from disco.types.user import User
from disco.types.guild import Role from disco.types.guild import Role
@ -41,11 +41,11 @@ class Message(BaseType):
pinned = skema.BooleanType(required=False) pinned = skema.BooleanType(required=False)
mentions = skema.ListType(skema.ModelType(User)) mentions = ListToDictType('id', skema.ModelType(User))
mention_roles = skema.ListType(skema.SnowflakeType()) mention_roles = skema.ListType(skema.SnowflakeType())
embeds = skema.ListType(skema.ModelType(MessageEmbed)) embeds = skema.ListType(skema.ModelType(MessageEmbed))
attachment = skema.ListType(skema.ModelType(MessageAttachment)) attachments = ListToDictType('id', skema.ModelType(MessageAttachment))
@cached_property @cached_property
def guild(self): def guild(self):
@ -55,14 +55,6 @@ class Message(BaseType):
def channel(self): def channel(self):
return self.client.state.channels.get(self.channel_id) return self.client.state.channels.get(self.channel_id)
@cached_property
def mention_users(self):
return [i.id for i in self.mentions]
@cached_property
def mention_users_dict(self):
return {i.id: i for i in self.mentions}
def reply(self, *args, **kwargs): def reply(self, *args, **kwargs):
return self.channel.send_message(*args, **kwargs) return self.channel.send_message(*args, **kwargs)
@ -74,11 +66,13 @@ class Message(BaseType):
def is_mentioned(self, entity): def is_mentioned(self, entity):
if isinstance(entity, User): if isinstance(entity, User):
return entity.id in self.mention_users return entity.id in self.mentions
elif isinstance(entity, Role): elif isinstance(entity, Role):
return entity.id in self.mention_roles return entity.id in self.mention_roles
elif isinstance(entity, long):
return entity in self.mentions or entity in self.mention_roles
else: else:
raise Exception('Unknown entity: {}'.format(entity)) raise Exception('Unknown entity: {} ({})'.format(entity, type(entity)))
@cached_property @cached_property
def without_mentions(self): def without_mentions(self):
@ -95,6 +89,6 @@ class Message(BaseType):
if id in self.mention_roles: if id in self.mention_roles:
return role_replace(id) return role_replace(id)
else: else:
return user_replace(self.mention_users_dict.get(id)) return user_replace(self.mentions.get(id))
return re.sub('<@!?([0-9]+)>', replace, self.content) return re.sub('<@!?([0-9]+)>', replace, self.content)

36
disco/util/__init__.py

@ -1,18 +1,36 @@
import skema
def recursive_find_matching(base, match_clause): def _recurse(typ, field, value):
result = [] result = []
if hasattr(base, '__dict__'): if isinstance(field, skema.ModelType):
values = base.__dict__.values() result += skema_find_recursive_by_type(value, typ)
else:
values = list(base)
for v in values: if isinstance(field, (skema.ListType, skema.SetType, skema.DictType)):
if match_clause(v): if isinstance(field, skema.DictType):
value = value.values()
for item in value:
if isinstance(field.field, typ):
result.append(item)
result += _recurse(typ, field.field, item)
return result
def skema_find_recursive_by_type(base, typ):
result = []
for name, field in base._fields_by_stored_name.items():
v = getattr(base, name, None)
if not v:
continue
if isinstance(field, typ):
result.append(v) result.append(v)
if hasattr(v, '__dict__') or hasattr(v, '__iter__'): result += _recurse(typ, field, v)
result += recursive_find_matching(v, match_clause)
return result return result

20
disco/util/types.py

@ -1,4 +1,4 @@
from skema import BaseType from skema import BaseType, DictType
class PreHookType(BaseType): class PreHookType(BaseType):
@ -16,3 +16,21 @@ class PreHookType(BaseType):
def to_storage(self, *args, **kwargs): def to_storage(self, *args, **kwargs):
return self.field.to_storage(*args, **kwargs) return self.field.to_storage(*args, **kwargs)
class ListToDictType(DictType):
def __init__(self, key, *args, **kwargs):
super(ListToDictType, self).__init__(*args, **kwargs)
self.key = key
def to_python(self, value):
if not value:
return {}
to_python = self.field.to_python
obj = {}
for item in value:
item = to_python(item)
obj[getattr(item, self.key)] = item
return obj

30
examples/basic_plugin.py

@ -19,6 +19,36 @@ class BasicPlugin(Plugin):
for i in range(count): for i in range(count):
event.msg.reply(content) event.msg.reply(content)
@Plugin.command('invites')
def on_invites(self, event):
invites = event.channel.get_invites()
event.msg.reply('Channel has a total of {} invites'.format(len(invites)))
@Plugin.command('pins')
def on_pins(self, event):
pins = event.channel.get_pins()
event.msg.reply('Channel has a total of {} pins'.format(len(pins)))
@Plugin.command('channel stats')
def on_stats(self, event):
msg = event.msg.reply('Ok, one moment...')
invite_count = len(event.channel.get_invites())
pin_count = len(event.channel.get_pins())
msg_count = 0
print event.channel.messages_iter(bulk=True)
for msgs in event.channel.messages_iter(bulk=True):
msg_count += len(msgs)
msg.edit('{} invites, {} pins, {} messages'.format(invite_count, pin_count, msg_count))
@Plugin.command('messages stack')
def on_messages_stack(self, event):
event.msg.reply('Channels: {}, messages here: ```\n{}\n```'.format(
len(self.state.messages),
'\n'.join([str(i.id) for i in self.state.messages[event.channel.id]])
))
if __name__ == '__main__': if __name__ == '__main__':
bot = Bot(disco_main()) bot = Bot(disco_main())
bot.add_plugin(BasicPlugin) bot.add_plugin(BasicPlugin)

Loading…
Cancel
Save