Browse Source

Lots of API stuff, state additions, skema/modeling cleanup

pull/3/head
Andrei 9 years ago
parent
commit
37d7b3bdef
  1. 132
      disco/api/client.py
  2. 146
      disco/api/http.py
  3. 12
      disco/bot/command.py
  4. 5
      disco/bot/parser.py
  5. 2
      disco/bot/plugin.py
  6. 24
      disco/gateway/events.py
  7. 74
      disco/state.py
  8. 10
      disco/types/base.py
  9. 87
      disco/types/channel.py
  10. 25
      disco/types/guild.py
  11. 22
      disco/types/invite.py
  12. 22
      disco/types/message.py
  13. 36
      disco/util/__init__.py
  14. 20
      disco/util/types.py
  15. 30
      examples/basic_plugin.py

132
disco/api/client.py

@ -1,12 +1,15 @@
from disco.api.http import Routes, HTTPClient
from disco.util.logging import LoggingClass
from disco.types.user import User
from disco.types.message import Message
from disco.types.guild import Guild, GuildMember, Role
from disco.types.channel import Channel
from disco.types.invite import Invite
def optional(**kwargs):
return {k: v for k, v in kwargs if v is not None}
return {k: v for k, v in kwargs.items() if v is not None}
class APIClient(LoggingClass):
@ -21,34 +24,33 @@ class APIClient(LoggingClass):
return data['url'] + '?v={}&encoding={}'.format(version, encoding)
def channels_get(self, channel):
r = self.http(Routes.CHANNELS_GET, channel)
r = self.http(Routes.CHANNELS_GET, dict(channel=channel))
return Channel.create(self.client, r.json())
def channels_modify(self, channel, **kwargs):
r = self.http(Routes.CHANNELS_MODIFY, channel, json=kwargs)
r = self.http(Routes.CHANNELS_MODIFY, dict(channel=channel), json=kwargs)
return Channel.create(self.client, r.json())
def channels_delete(self, channel):
r = self.http(Routes.CHANNELS_DELETE, channel)
r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel))
return Channel.create(self.client, r.json())
def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50):
r = self.http(Routes.CHANNELS_MESSAGES_LIST, channel, json=optional(
channel=channel,
r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional(
around=around,
before=before,
after=after,
limit=limit
))
return [Message.create(self.client, i) for i in r.json()]
return Message.create_map(self.client, r.json())
def channels_messages_get(self, channel, message):
r = self.http(Routes.CHANNELS_MESSAGES_GET, channel, message)
r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message))
return Message.create(self.client, r.json())
def channels_messages_create(self, channel, content, nonce=None, tts=False):
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, channel, json={
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={
'content': content,
'nonce': nonce,
'tts': tts,
@ -57,11 +59,117 @@ class APIClient(LoggingClass):
return Message.create(self.client, r.json())
def channels_messages_modify(self, channel, message, content):
r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, channel, message, json={'content': content})
r = self.http(Routes.CHANNELS_MESSAGES_MODIFY,
dict(channel=channel, message=message),
json={'content': content})
return Message.create(self.client, r.json())
def channels_messages_delete(self, channel, message):
self.http(Routes.CHANNELS_MESSAGES_DELETE, channel, message)
self.http(Routes.CHANNELS_MESSAGES_DELETE, dict(channel=channel, message=message))
def channels_messages_delete_bulk(self, channel, messages):
self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, channel, json={'messages': messages})
self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages})
def channels_permissions_modify(self, channel, permission, allow, deny, typ):
self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={
'allow': allow,
'deny': deny,
'type': typ,
})
def channels_permissions_delete(self, channel, permission):
self.http(Routes.CHANNELS_PERMISSIONS_DELETE, dict(channel=channel, permission=permission))
def channels_invites_list(self, channel):
r = self.http(Routes.CHANNELS_INVITES_LIST, dict(channel=channel))
return Invite.create_map(self.client, r.json())
def channels_invites_create(self, channel, max_age=86400, max_uses=0, temporary=False, unique=False):
r = self.http(Routes.CHANNELS_INVITES_CREATE, dict(channel=channel), json={
'max_age': max_age,
'max_uses': max_uses,
'temporary': temporary,
'unique': unique
})
return Invite.create(self.client, r.json())
def channels_pins_list(self, channel):
r = self.http(Routes.CHANNELS_PINS_LIST, dict(channel=channel))
return Message.create_map(self.client, r.json())
def channels_pins_create(self, channel, message):
self.http(Routes.CHANNELS_PINS_CREATE, dict(channel=channel, message=message))
def channels_pins_delete(self, channel, message):
self.http(Routes.CHANNELS_PINS_DELETE, dict(channel=channel, message=message))
def guilds_get(self, guild):
r = self.http(Routes.GUILDS_GET, dict(guild=guild))
return Guild.create(self.client, r.json())
def guilds_modify(self, guild, **kwargs):
r = self.http(Routes.GUILDS_MODIFY, dict(guild=guild), json=kwargs)
return Guild.create(self.client, r.json())
def guilds_delete(self, guild):
r = self.http(Routes.GUILDS_DELETE, dict(guild=guild))
return Guild.create(self.client, r.json())
def guilds_channels_list(self, guild):
r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
return Channel.create_map(self.client, r.json())
def guilds_channels_create(self, guild, **kwargs):
r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs)
return Channel.create(self.client, r.json())
def guilds_channels_modify(self, guild, channel, position):
self.http(Routes.GUILDS_CHANNELS_MODIFY, dict(guild=guild), json={
'id': channel,
'position': position,
})
def guilds_members_list(self, guild):
r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild))
return GuildMember.create_map(self.client, r.json())
def guilds_members_get(self, guild, member):
r = self.http(Routes.GUILD_MEMBERS_GET, dict(guild=guild, member=member))
return GuildMember.create(self.client, r.json())
def guilds_members_modify(self, guild, member, **kwargs):
self.http(Routes.GUILD_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs)
def guilds_members_kick(self, guild, member):
self.http(Routes.GUILD_MEMBERS_KICK, dict(guild=guild, member=member))
def guilds_bans_list(self, guild):
r = self.http(Routes.GUILD_BANS_LIST, dict(guild=guild))
return User.create_map(self.client, r.json())
def guilds_bans_create(self, guild, user, delete_message_days):
self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={
'delete-message-days': delete_message_days,
})
def guilds_bans_delete(self, guild, user):
self.http(Routes.GUILDS_BANS_DELETE, dict(guild=guild, user=user))
def guilds_roles_list(self, guild):
r = self.http(Routes.GUILDS_ROLES_LIST, dict(guild=guild))
return Role.create_map(self.client, r.json())
def guilds_roles_create(self, guild):
r = self.http(Routes.GUILDS_ROLES_CREATE, dict(guild=guild))
return Role.create(self.client, r.json())
def guilds_roles_modify_batch(self, guild, roles):
r = self.http(Routes.GUILDS_ROLES_MODIFY_BATCH, dict(guild=guild), json=roles)
return Role.create_map(self.client, r.json())
def guilds_roles_modify(self, guild, role, **kwargs):
r = self.http(Routes.GUILDS_ROLES_MODIFY, dict(guild=guild, role=role), json=kwargs)
return Role.create(self.client, r.json())
def guilds_roles_delete(self, guild, role):
self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role))

146
disco/api/http.py

@ -1,4 +1,6 @@
import requests
import random
import gevent
from holster.enum import Enum
@ -19,66 +21,69 @@ class Routes(object):
GATEWAY_GET = (HTTPMethod.GET, '/gateway')
# Channels
CHANNELS_GET = (HTTPMethod.GET, '/channels/{}')
CHANNELS_MODIFY= (HTTPMethod.PATCH, '/channels/{}')
CHANNELS_DELETE = (HTTPMethod.DELETE, '/channels/{}')
CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, '/channels/{}/messages')
CHANNELS_MESSAGES_GET = (HTTPMethod.GET, '/channels/{}/messages/{}')
CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, '/channels/{}/messages')
CHANNELS_MESSAGES_MODFIY = (HTTPMethod.PATCH, '/channels/{}/messages/{}')
CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, '/channels/{}/messages/{}')
CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, '/channels/{}/messages/bulk_delete')
CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, '/channels/{}/permissions/{}')
CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, '/channels/{}/permissions/{}')
CHANNELS_INVITES_LIST = (HTTPMethod.GET, '/channels/{}/invites')
CHANNELS_INVITES_CREATE = (HTTPMethod.POST, '/channels/{}/invites')
CHANNELS_PINS_LIST = (HTTPMethod.GET, '/channels/{}/pins')
CHANNELS_PINS_CREATE = (HTTPMethod.PUT, '/channels/{}/pins/{}')
CHANNELS_PINS_DELETE = (HTTPMethod.DELETE, '/channels/{}/pins/{}')
CHANNELS = '/channels/{channel}'
CHANNELS_GET = (HTTPMethod.GET, CHANNELS)
CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS)
CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS)
CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages')
CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages')
CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete')
CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}')
CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}')
CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites')
CHANNELS_INVITES_CREATE = (HTTPMethod.POST, CHANNELS + '/invites')
CHANNELS_PINS_LIST = (HTTPMethod.GET, CHANNELS + '/pins')
CHANNELS_PINS_CREATE = (HTTPMethod.PUT, CHANNELS + '/pins/{pin}')
CHANNELS_PINS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/pins/{pin}')
# Guilds
GUILDS_GET = (HTTPMethod.GET, '/guilds/{}')
GUILDS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}')
GUILDS_DELETE = (HTTPMethod.DELETE, '/guilds/{}')
GUILDS_CHANNELS_LIST = (HTTPMethod.GET, '/guilds/{}/channels')
GUILDS_CHANNELS_CREATE = (HTTPMethod.POST, '/guilds/{}/channels')
GUILDS_CHANNELS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/channels')
GUILDS_MEMBERS_LIST = (HTTPMethod.GET, '/guilds/{}/members')
GUILDS_MEMBERS_GET = (HTTPMethod.GET, '/guilds/{}/members/{}')
GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/members/{}')
GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, '/guilds/{}/members/{}')
GUILDS_BANS_LIST = (HTTPMethod.GET, '/guilds/{}/bans')
GUILDS_BANS_CREATE = (HTTPMethod.PUT, '/guilds/{}/bans/{}')
GUILDS_BANS_DELETE = (HTTPMethod.DELETE, '/guilds/{}/bans/{}')
GUILDS_ROLES_LIST = (HTTPMethod.GET, '/guilds/{}/roles')
GUILDS_ROLES_CREATE = (HTTPMethod.GET, '/guilds/{}/roles')
GUILDS_ROLES_MODIFY_BATCH = (HTTPMethod.PATCH, '/guilds/{}/roles')
GUILDS_ROLES_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/roles/{}')
GUILDS_ROLES_DELETE = (HTTPMethod.DELETE, '/guilds/{}/roles/{}')
GUILDS_PRUNE_COUNT = (HTTPMethod.GET, '/guilds/{}/prune')
GUILDS_PRUNE_BEGIN = (HTTPMethod.POST, '/guilds/{}/prune')
GUILDS_VOICE_REGIONS_LIST = (HTTPMethod.GET, '/guilds/{}/regions')
GUILDS_INVITES_LIST = (HTTPMethod.GET, '/guilds/{}/invites')
GUILDS_INTEGRATIONS_LIST = (HTTPMethod.GET, '/guilds/{}/integrations')
GUILDS_INTEGRATIONS_CREATE = (HTTPMethod.POST, '/guilds/{}/integrations')
GUILDS_INTEGRATIONS_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/integrations/{}')
GUILDS_INTEGRATIONS_DELETE = (HTTPMethod.DELETE, '/guilds/{}/integrations/{}')
GUILDS_INTEGRATIONS_SYNC = (HTTPMethod.POST, '/guilds/{}/integrations/{}/sync')
GUILDS_EMBED_GET = (HTTPMethod.GET, '/guilds/{}/embed')
GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, '/guilds/{}/embed')
GUILDS = '/guilds/{guild}'
GUILDS_GET = (HTTPMethod.GET, GUILDS)
GUILDS_MODIFY = (HTTPMethod.PATCH, GUILDS)
GUILDS_DELETE = (HTTPMethod.DELETE, GUILDS)
GUILDS_CHANNELS_LIST = (HTTPMethod.GET, GUILDS + '/channels')
GUILDS_CHANNELS_CREATE = (HTTPMethod.POST, GUILDS + '/channels')
GUILDS_CHANNELS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/channels')
GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members')
GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}')
GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}')
GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}')
GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans')
GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}')
GUILDS_BANS_DELETE = (HTTPMethod.DELETE, GUILDS + '/bans/{user}')
GUILDS_ROLES_LIST = (HTTPMethod.GET, GUILDS + '/roles')
GUILDS_ROLES_CREATE = (HTTPMethod.GET, GUILDS + '/roles')
GUILDS_ROLES_MODIFY_BATCH = (HTTPMethod.PATCH, GUILDS + '/roles')
GUILDS_ROLES_MODIFY = (HTTPMethod.PATCH, GUILDS + '/roles/{role}')
GUILDS_ROLES_DELETE = (HTTPMethod.DELETE, GUILDS + '/roles/{role}')
GUILDS_PRUNE_COUNT = (HTTPMethod.GET, GUILDS + '/prune')
GUILDS_PRUNE_BEGIN = (HTTPMethod.POST, GUILDS + '/prune')
GUILDS_VOICE_REGIONS_LIST = (HTTPMethod.GET, GUILDS + '/regions')
GUILDS_INVITES_LIST = (HTTPMethod.GET, GUILDS + '/invites')
GUILDS_INTEGRATIONS_LIST = (HTTPMethod.GET, GUILDS + '/integrations')
GUILDS_INTEGRATIONS_CREATE = (HTTPMethod.POST, GUILDS + '/integrations')
GUILDS_INTEGRATIONS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/integrations/{integration}')
GUILDS_INTEGRATIONS_DELETE = (HTTPMethod.DELETE, GUILDS + '/integrations/{integration}')
GUILDS_INTEGRATIONS_SYNC = (HTTPMethod.POST, GUILDS + '/integrations/{integration}/sync')
GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed')
GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed')
# Users
USERS_ME_GET = (HTTPMethod.GET, '/users/@me')
USERS_ME_PATCH = (HTTPMethod.PATCH, '/users/@me')
USERS_ME_GUILDS_LIST = (HTTPMethod.GET, '/users/@me/guilds')
USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, '/users/@me/guilds/{}')
USERS_ME_DMS_LIST = (HTTPMethod.GET, '/users/@me/channels')
USERS_ME_DMS_CREATE = (HTTPMethod.POST, '/users/@me/channels')
USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, '/users/@me/connections')
USERS_GET = (HTTPMethod.GET, '/users/{}')
USERS = '/users'
USERS_ME_GET = (HTTPMethod.GET, USERS + '/@me')
USERS_ME_PATCH = (HTTPMethod.PATCH, USERS + '/@me')
USERS_ME_GUILDS_LIST = (HTTPMethod.GET, USERS + '/@me/guilds')
USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, USERS + '/@me/guilds/{guild}')
USERS_ME_DMS_LIST = (HTTPMethod.GET, USERS + '/@me/channels')
USERS_ME_DMS_CREATE = (HTTPMethod.POST, USERS + '/@me/channels')
USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, USERS + '/@me/connections')
USERS_GET = (HTTPMethod.GET, USERS + '/{user}')
class APIException(Exception):
@ -89,7 +94,7 @@ class APIException(Exception):
class HTTPClient(LoggingClass):
BASE_URL = 'https://discordapp.com/api'
BASE_URL = 'https://discordapp.com/api/v6'
MAX_RETRIES = 5
def __init__(self, token):
@ -100,7 +105,8 @@ class HTTPClient(LoggingClass):
'Authorization': 'Bot ' + token,
}
def __call__(self, route, *args, **kwargs):
def __call__(self, route, args=None, **kwargs):
args = args or {}
retry = kwargs.pop('retry_number', 0)
# Merge or set headers
@ -109,17 +115,20 @@ class HTTPClient(LoggingClass):
else:
kwargs['headers'] = self.headers
# Compile URL args
compiled = (str(route[0]), (self.BASE_URL) + route[1].format(*args))
# Build the bucket URL
filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in args.items()}
bucket = (route[0].value, route[1].format(**filtered))
# Possibly wait if we're rate limited
self.limiter.check(compiled)
self.limiter.check(bucket)
# Make the actual request
r = requests.request(compiled[0], compiled[1], **kwargs)
url = self.BASE_URL + route[1].format(**args)
print route[0].value, url, kwargs
r = requests.request(route[0].value, url, **kwargs)
# Update rate limiter
self.limiter.update(compiled, r)
self.limiter.update(bucket, r)
# If we got a success status code, just return the data
if r.status_code < 400:
@ -134,5 +143,14 @@ class HTTPClient(LoggingClass):
self.log.error('Failing request, hit max retries')
raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content)
backoff = self.random_backoff()
self.log.warning('Request to `{}` failed with code {}, retrying after {}s'.format(url, r.status_code, backoff))
gevent.sleep(backoff)
# Otherwise just recurse and try again
return self(route, retry_number=retry, *args, **kwargs)
return self(route, args, retry_number=retry, **kwargs)
@staticmethod
def random_backoff():
# 500 milliseconds to 5 seconds)
return random.randint(500, 5000) / 1000.0

12
disco/bot/command.py

@ -15,6 +15,18 @@ class CommandEvent(object):
self.name = self.match.group(1)
self.args = self.match.group(2).strip().split(' ')
@property
def channel(self):
return self.msg.channel
@property
def guild(self):
return self.msg.guild
@property
def actor(self):
return self.msg.author
class CommandError(Exception):
pass

5
disco/bot/parser.py

@ -77,6 +77,9 @@ class ArgumentSet(object):
if not arg.required and index + arg.true_count <= len(rawargs):
continue
if arg.count == 0:
raw = rawargs[index:]
else:
raw = rawargs[index:index + arg.true_count]
if arg.types:
@ -88,7 +91,7 @@ class ArgumentSet(object):
r, ', '.join(arg.types)
))
if arg.true_count == 1:
if arg.count == 1:
raw = raw[0]
if not arg.types or arg.types == ['str'] and isinstance(raw, list):

2
disco/bot/plugin.py

@ -61,6 +61,8 @@ class Plugin(LoggingClass, PluginDeco):
def __init__(self, bot, config):
super(Plugin, self).__init__()
self.bot = bot
self.client = bot.client
self.state = bot.client.state
self.config = config
self.listeners = []

24
disco/gateway/events.py

@ -1,8 +1,7 @@
import inflection
import skema
from disco.util import recursive_find_matching
from disco.types.base import BaseType
from disco.util import skema_find_recursive_by_type
from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState
@ -15,8 +14,7 @@ class GatewayEvent(skema.Model):
obj = cls.create(obj['d'])
# TODO: use skema info
for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)):
for item in skema_find_recursive_by_type(obj, skema.ModelType):
item.client = client
return obj
@ -68,13 +66,17 @@ class GuildDelete(GatewayEvent):
class ChannelCreate(Sub('channel')):
channel = skema.ModelType(Channel)
@property
def guild(self):
return self.channel.guild
class ChannelUpdate(Sub('channel')):
channel = skema.ModelType(Channel)
class ChannelUpdate(ChannelCreate):
pass
class ChannelDelete(Sub('channel')):
channel = skema.ModelType(Channel)
class ChannelDelete(ChannelCreate):
pass
class ChannelPinsUpdate(GatewayEvent):
@ -136,8 +138,12 @@ class GuildRoleDelete(GatewayEvent):
class MessageCreate(Sub('message')):
message = skema.ModelType(Message)
@property
def channel(self):
return self.message.channel
class MessageUpdate(Sub('message')):
class MessageUpdate(MessageCreate):
message = skema.ModelType(Message)

74
disco/state.py

@ -1,16 +1,36 @@
from collections import defaultdict, deque, namedtuple
from weakref import WeakValueDictionary
StackMessage = namedtuple('StackMessage', ['id', 'channel_id', 'author_id'])
class StateConfig(object):
# Whether to keep a buffer of messages
track_messages = True
# The number maximum number of messages to store
track_messages_size = 100
class State(object):
def __init__(self, client):
def __init__(self, client, config=None):
self.client = client
self.config = config or StateConfig()
self.me = None
self.channels = {}
self.dms = {}
self.guilds = {}
self.channels = WeakValueDictionary()
self.client.events.on('Ready', self.on_ready)
self.messages_stack = defaultdict(lambda: deque(maxlen=self.config.track_messages_size))
if self.config.track_messages:
self.client.events.on('MessageCreate', self.on_message_create)
self.client.events.on('MessageDelete', self.on_message_delete)
# Guilds
self.client.events.on('GuildCreate', self.on_guild_create)
self.client.events.on('GuildUpdate', self.on_guild_update)
@ -24,27 +44,63 @@ class State(object):
def on_ready(self, event):
self.me = event.user
def on_message_create(self, event):
self.messages_stack[event.message.channel_id].append(
StackMessage(event.message.id, event.message.channel_id, event.message.author.id))
def on_message_update(self, event):
message, cid = event.message, event.message.channel_id
if cid not in self.messages_stack:
return
sm = next((i for i in self.messages_stack[cid] if i.id == message.id), None)
if not sm:
return
sm.id = message.id
sm.channel_id = cid
sm.author_id = message.author.id
def on_message_delete(self, event):
if event.channel_id not in self.messages_stack:
return
sm = next((i for i in self.messages_stack[event.channel_id] if i.id == event.id), None)
if not sm:
return
self.messages_stack[event.channel_id].remove(sm)
def on_guild_create(self, event):
self.guilds[event.guild.id] = event.guild
for channel in event.guild.channels:
self.channels[channel.id] = channel
self.channels.update(event.guild.channels)
def on_guild_update(self, event):
self.guilds[event.guild.id] = event.guild
def on_guild_delete(self, event):
if event.guild_id in self.guilds:
# Just delete the guild, channel references will fall
del self.guilds[event.guild_id]
# CHANNELS?
def on_channel_create(self, event):
if event.channel.is_guild and event.channel.guild_id in self.guilds:
self.guilds[event.channel.guild_id].channels[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel
elif event.channel.is_dm:
self.dms[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel
def on_channel_update(self, event):
if event.channel.is_guild and event.channel.guild_id in self.guilds:
self.guilds[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel
elif event.channel.is_dm:
self.dms[event.channel.id] = event.channel
self.channels[event.channel.id] = event.channel
def on_channel_delete(self, event):
if event.channel.id in self.channels:
del self.channels[event.channel.id]
if event.channel.is_guild and event.channel.guild_id in self.guilds:
del self.guilds[event.channel.id]
elif event.channel.is_dm:
del self.pms[event.channel.id]

10
disco/types/base.py

@ -1,6 +1,7 @@
import skema
import functools
from disco.util import recursive_find_matching
from disco.util import skema_find_recursive_by_type
class BaseType(skema.Model):
@ -11,9 +12,12 @@ class BaseType(skema.Model):
# Valdiate
obj.validate()
# TODO: this can be smarter using skema metadata
for item in recursive_find_matching(obj, lambda v: isinstance(v, BaseType)):
for item in skema_find_recursive_by_type(obj, skema.ModelType):
item.client = client
obj.client = client
return obj
@classmethod
def create_map(cls, client, data):
return map(functools.partial(cls.create, client), data)

87
disco/types/channel.py

@ -3,6 +3,7 @@ import skema
from holster.enum import Enum
from disco.util.cache import cached_property
from disco.util.types import ListToDictType
from disco.types.base import BaseType
from disco.types.user import User
@ -34,18 +35,100 @@ class Channel(BaseType):
name = skema.StringType()
topic = skema.StringType()
last_message_id = skema.SnowflakeType()
_last_message_id = skema.SnowflakeType(stored_name='last_message_id')
position = skema.IntType()
bitrate = skema.IntType(required=False)
recipient = skema.ModelType(User, required=False)
type = skema.IntType(choices=ChannelType.ALL_VALUES)
permission_overwrites = skema.ListType(skema.ModelType(PermissionOverwrite))
overwrites = ListToDictType('id', skema.ModelType(PermissionOverwrite), stored_name='permission_overwrites')
@property
def is_guild(self):
return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE)
@property
def is_dm(self):
return self.type in (ChannelType.DM, ChannelType.GROUP_DM)
@property
def last_message_id(self):
if self.id not in self.client.state.messages_stack:
return self._last_message_id
return self.client.state.messages_stack[self.id][-1].id
@property
def messages(self):
return self.messages_iter()
def messages_iter(self, **kwargs):
return MessageIterator(self.client, self.id, before=self.last_message_id, **kwargs)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
def get_invites(self):
return self.client.api.channels_invites_list(self.id)
def get_pins(self):
return self.client.api.channels_pins_list(self.id)
def send_message(self, content, nonce=None, tts=False):
return self.client.api.channels_messages_create(self.id, content, nonce, tts)
class MessageIterator(object):
Direction = Enum('UP', 'DOWN')
def __init__(self, client, channel, direction=Direction.UP, bulk=False, before=None, after=None, chunk_size=100):
self.client = client
self.channel = channel
self.direction = direction
self.bulk = bulk
self.before = before
self.after = after
self.chunk_size = chunk_size
self.last = None
self._buffer = []
if len(filter(bool, (before, after))) > 1:
raise Exception('Must specify at most one of before or after')
if not any((before, after)) and self.direction == self.Direction.DOWN:
raise Exception('Must specify either before or after for downward seeking')
def fill(self):
self._buffer = self.client.api.channels_messages_list(
self.channel,
before=self.before,
after=self.after,
limit=self.chunk_size)
if not len(self._buffer):
raise StopIteration
self.after = None
self.before = None
if self.direction == self.Direction.UP:
self.before = self._buffer[-1].id
else:
self._buffer.reverse()
self.after == self._buffer[-1].id
def __iter__(self):
return self
def next(self):
if not len(self._buffer):
self.fill()
if self.bulk:
res = self._buffer
self._buffer = []
return res
else:
return self._buffer.pop()

25
disco/types/guild.py

@ -1,8 +1,7 @@
import skema
from disco.util.cache import cached_property
from disco.types.base import BaseType
from disco.util.types import PreHookType
from disco.util.types import PreHookType, ListToDictType
from disco.types.user import User
from disco.types.voice import VoiceState
from disco.types.channel import Channel
@ -33,6 +32,10 @@ class GuildMember(BaseType):
joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType())
roles = skema.ListType(skema.SnowflakeType())
@property
def id(self):
return self.user.id
class Guild(BaseType):
id = skema.SnowflakeType()
@ -53,20 +56,16 @@ class Guild(BaseType):
features = skema.ListType(skema.StringType())
members = skema.ListType(skema.ModelType(GuildMember))
voice_states = skema.ListType(skema.ModelType(VoiceState))
channels = skema.ListType(skema.ModelType(Channel))
roles = skema.ListType(skema.ModelType(Role))
emojis = skema.ListType(skema.ModelType(Emoji))
@cached_property
def members_dict(self):
return {i.user.id: i for i in self.members}
members = ListToDictType('id', skema.ModelType(GuildMember))
channels = ListToDictType('id', skema.ModelType(Channel))
roles = ListToDictType('id', skema.ModelType(Role))
emojis = ListToDictType('id', skema.ModelType(Emoji))
voice_states = ListToDictType('id', skema.ModelType(VoiceState))
def get_member(self, user):
return self.members_dict.get(user.id)
return self.members.get(user.id)
def validate_channels(self, ctx):
if self.channels:
for channel in self.channels:
for channel in self.channels.values():
channel.guild_id = self.id

22
disco/types/invite.py

@ -0,0 +1,22 @@
import skema
from disco.util.types import PreHookType
from disco.types.base import BaseType
from disco.types.user import User
from disco.types.guild import Guild
from disco.types.channel import Channel
class Invite(BaseType):
code = skema.StringType()
inviter = skema.ModelType(User)
guild = skema.ModelType(Guild)
channel = skema.ModelType(Channel)
max_age = skema.IntType()
max_uses = skema.IntType()
uses = skema.IntType()
temporary = skema.BooleanType()
created_at = PreHookType(lambda k: k[:-6], skema.DateTimeType())

22
disco/types/message.py

@ -2,7 +2,7 @@ import re
import skema
from disco.util.cache import cached_property
from disco.util.types import PreHookType
from disco.util.types import PreHookType, ListToDictType
from disco.types.base import BaseType
from disco.types.user import User
from disco.types.guild import Role
@ -41,11 +41,11 @@ class Message(BaseType):
pinned = skema.BooleanType(required=False)
mentions = skema.ListType(skema.ModelType(User))
mentions = ListToDictType('id', skema.ModelType(User))
mention_roles = skema.ListType(skema.SnowflakeType())
embeds = skema.ListType(skema.ModelType(MessageEmbed))
attachment = skema.ListType(skema.ModelType(MessageAttachment))
attachments = ListToDictType('id', skema.ModelType(MessageAttachment))
@cached_property
def guild(self):
@ -55,14 +55,6 @@ class Message(BaseType):
def channel(self):
return self.client.state.channels.get(self.channel_id)
@cached_property
def mention_users(self):
return [i.id for i in self.mentions]
@cached_property
def mention_users_dict(self):
return {i.id: i for i in self.mentions}
def reply(self, *args, **kwargs):
return self.channel.send_message(*args, **kwargs)
@ -74,11 +66,13 @@ class Message(BaseType):
def is_mentioned(self, entity):
if isinstance(entity, User):
return entity.id in self.mention_users
return entity.id in self.mentions
elif isinstance(entity, Role):
return entity.id in self.mention_roles
elif isinstance(entity, long):
return entity in self.mentions or entity in self.mention_roles
else:
raise Exception('Unknown entity: {}'.format(entity))
raise Exception('Unknown entity: {} ({})'.format(entity, type(entity)))
@cached_property
def without_mentions(self):
@ -95,6 +89,6 @@ class Message(BaseType):
if id in self.mention_roles:
return role_replace(id)
else:
return user_replace(self.mention_users_dict.get(id))
return user_replace(self.mentions.get(id))
return re.sub('<@!?([0-9]+)>', replace, self.content)

36
disco/util/__init__.py

@ -1,18 +1,36 @@
import skema
def recursive_find_matching(base, match_clause):
def _recurse(typ, field, value):
result = []
if hasattr(base, '__dict__'):
values = base.__dict__.values()
else:
values = list(base)
if isinstance(field, skema.ModelType):
result += skema_find_recursive_by_type(value, typ)
for v in values:
if match_clause(v):
if isinstance(field, (skema.ListType, skema.SetType, skema.DictType)):
if isinstance(field, skema.DictType):
value = value.values()
for item in value:
if isinstance(field.field, typ):
result.append(item)
result += _recurse(typ, field.field, item)
return result
def skema_find_recursive_by_type(base, typ):
result = []
for name, field in base._fields_by_stored_name.items():
v = getattr(base, name, None)
if not v:
continue
if isinstance(field, typ):
result.append(v)
if hasattr(v, '__dict__') or hasattr(v, '__iter__'):
result += recursive_find_matching(v, match_clause)
result += _recurse(typ, field, v)
return result

20
disco/util/types.py

@ -1,4 +1,4 @@
from skema import BaseType
from skema import BaseType, DictType
class PreHookType(BaseType):
@ -16,3 +16,21 @@ class PreHookType(BaseType):
def to_storage(self, *args, **kwargs):
return self.field.to_storage(*args, **kwargs)
class ListToDictType(DictType):
def __init__(self, key, *args, **kwargs):
super(ListToDictType, self).__init__(*args, **kwargs)
self.key = key
def to_python(self, value):
if not value:
return {}
to_python = self.field.to_python
obj = {}
for item in value:
item = to_python(item)
obj[getattr(item, self.key)] = item
return obj

30
examples/basic_plugin.py

@ -19,6 +19,36 @@ class BasicPlugin(Plugin):
for i in range(count):
event.msg.reply(content)
@Plugin.command('invites')
def on_invites(self, event):
invites = event.channel.get_invites()
event.msg.reply('Channel has a total of {} invites'.format(len(invites)))
@Plugin.command('pins')
def on_pins(self, event):
pins = event.channel.get_pins()
event.msg.reply('Channel has a total of {} pins'.format(len(pins)))
@Plugin.command('channel stats')
def on_stats(self, event):
msg = event.msg.reply('Ok, one moment...')
invite_count = len(event.channel.get_invites())
pin_count = len(event.channel.get_pins())
msg_count = 0
print event.channel.messages_iter(bulk=True)
for msgs in event.channel.messages_iter(bulk=True):
msg_count += len(msgs)
msg.edit('{} invites, {} pins, {} messages'.format(invite_count, pin_count, msg_count))
@Plugin.command('messages stack')
def on_messages_stack(self, event):
event.msg.reply('Channels: {}, messages here: ```\n{}\n```'.format(
len(self.state.messages),
'\n'.join([str(i.id) for i in self.state.messages[event.channel.id]])
))
if __name__ == '__main__':
bot = Bot(disco_main())
bot.add_plugin(BasicPlugin)

Loading…
Cancel
Save