Browse Source

Merge branch 'master' into feature/voice

feature/voice
Andrei 8 years ago
parent
commit
027ecbf9e0
  1. 3
      README.md
  2. 2
      disco/__init__.py
  3. 151
      disco/api/client.py
  4. 82
      disco/api/http.py
  5. 21
      disco/api/ratelimit.py
  6. 117
      disco/bot/bot.py
  7. 140
      disco/bot/command.py
  8. 76
      disco/bot/parser.py
  9. 187
      disco/bot/plugin.py
  10. 1
      disco/bot/providers/disk.py
  11. 25
      disco/bot/providers/redis.py
  12. 6
      disco/bot/providers/rocksdb.py
  13. 17
      disco/cli.py
  14. 45
      disco/client.py
  15. 52
      disco/gateway/client.py
  16. 4
      disco/gateway/encoding/base.py
  17. 2
      disco/gateway/encoding/json.py
  18. 424
      disco/gateway/events.py
  19. 91
      disco/gateway/ipc.py
  20. 4
      disco/gateway/packets.py
  21. 104
      disco/gateway/sharder.py
  22. 68
      disco/state.py
  23. 1
      disco/types/__init__.py
  24. 272
      disco/types/base.py
  25. 102
      disco/types/channel.py
  26. 129
      disco/types/guild.py
  27. 6
      disco/types/invite.py
  28. 200
      disco/types/message.py
  29. 12
      disco/types/permissions.py
  30. 40
      disco/types/user.py
  31. 2
      disco/types/voice.py
  32. 6
      disco/types/webhook.py
  33. 2
      disco/util/config.py
  34. 4
      disco/util/hashmap.py
  35. 3
      disco/util/limiter.py
  36. 35
      disco/util/logging.py
  37. 36
      disco/util/serializer.py
  38. 2
      disco/util/snowflake.py
  39. 2
      disco/util/token.py
  40. 7
      disco/voice/client.py
  41. 127
      disco/voice/opus.py
  42. 49
      disco/voice/player.py
  43. 9
      examples/music.py
  44. 2
      requirements.txt

3
README.md

@ -20,6 +20,7 @@ Disco was built to run both as a generic-use library, and a standalone bot toolk
|requests[security]|adds packages for a proper SSL implementation|
|ujson|faster json parser, improves performance|
|erlpack|ETF parser, only Python 2.x, run with the --encoder=etf flag|
|gipc|Gevent IPC, required for autosharding|
## Examples
@ -48,7 +49,7 @@ class SimplePlugin(Plugin):
Using the default bot configuration, we can now run this script like so:
`python -m disco.cli --token="MY_DISCORD_TOKEN" --bot --plugin simpleplugin`
`python -m disco.cli --token="MY_DISCORD_TOKEN" --run-bot --plugin simpleplugin`
And commands can be triggered by mentioning the bot (configued by the BotConfig.command\_require\_mention flag):

2
disco/__init__.py

@ -1 +1 @@
VERSION = '0.0.5'
VERSION = '0.0.7'

151
disco/api/client.py

@ -1,11 +1,12 @@
import six
import json
from disco.api.http import Routes, HTTPClient
from disco.util.logging import LoggingClass
from disco.types.user import User
from disco.types.message import Message
from disco.types.guild import Guild, GuildMember, Role
from disco.types.guild import Guild, GuildMember, GuildBan, Role, GuildEmoji
from disco.types.channel import Channel
from disco.types.invite import Invite
from disco.types.webhook import Webhook
@ -23,18 +24,40 @@ def optional(**kwargs):
class APIClient(LoggingClass):
"""
An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits
the models with the returned data.
An abstraction over a :class:`disco.api.http.HTTPClient`, which composes
requests from provided data, and fits models with the returned data. The APIClient
is the only path to the API used within models/other interfaces, and it's
the recommended path for all third-party users/implementations.
Args
----
token : str
The Discord authentication token (without prefixes) to be used for all
HTTP requests.
client : Optional[:class:`disco.client.Client`]
The Disco client this APIClient is a member of. This is used when constructing
and fitting models from response data.
Attributes
----------
client : Optional[:class:`disco.client.Client`]
The Disco client this APIClient is a member of.
http : :class:`disco.http.HTTPClient`
The HTTPClient this APIClient uses for all requests.
"""
def __init__(self, client):
def __init__(self, token, client=None):
super(APIClient, self).__init__()
self.client = client
self.http = HTTPClient(self.client.config.token)
self.http = HTTPClient(token)
def gateway(self, version, encoding):
def gateway_get(self):
data = self.http(Routes.GATEWAY_GET).json()
return data['url'] + '?v={}&encoding={}'.format(version, encoding)
return data
def gateway_bot_get(self):
data = self.http(Routes.GATEWAY_BOT_GET).json()
return data
def channels_get(self, channel):
r = self.http(Routes.CHANNELS_GET, dict(channel=channel))
@ -48,6 +71,9 @@ class APIClient(LoggingClass):
r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel))
return Channel.create(self.client, r.json())
def channels_typing(self, channel):
self.http(Routes.CHANNELS_TYPING, dict(channel=channel))
def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50):
r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional(
around=around,
@ -62,19 +88,36 @@ class APIClient(LoggingClass):
r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message))
return Message.create(self.client, r.json())
def channels_messages_create(self, channel, content, nonce=None, tts=False):
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={
def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None):
payload = {
'content': content,
'nonce': nonce,
'tts': tts,
})
}
if embed:
payload['embed'] = embed.to_dict()
if attachment:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data={'payload_json': json.dumps(payload)}, files={
'file': (attachment[0], attachment[1])
})
else:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload)
return Message.create(self.client, r.json())
def channels_messages_modify(self, channel, message, content):
def channels_messages_modify(self, channel, message, content, embed=None):
payload = {
'content': content,
}
if embed:
payload['embed'] = embed.to_dict()
r = self.http(Routes.CHANNELS_MESSAGES_MODIFY,
dict(channel=channel, message=message),
json={'content': content})
dict(channel=channel, message=message),
json=payload)
return Message.create(self.client, r.json())
def channels_messages_delete(self, channel, message):
@ -83,6 +126,23 @@ class APIClient(LoggingClass):
def channels_messages_delete_bulk(self, channel, messages):
self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages})
def channels_messages_reactions_get(self, channel, message, emoji):
r = self.http(Routes.CHANNELS_MESSAGES_REACTIONS_GET, dict(channel=channel, message=message, emoji=emoji))
return User.create_map(self.client, r.json())
def channels_messages_reactions_create(self, channel, message, emoji):
self.http(Routes.CHANNELS_MESSAGES_REACTIONS_CREATE, dict(channel=channel, message=message, emoji=emoji))
def channels_messages_reactions_delete(self, channel, message, emoji, user=None):
route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_ME
obj = dict(channel=channel, message=message, emoji=emoji)
if user:
route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_USER
obj['user'] = user
self.http(route, obj)
def channels_permissions_modify(self, channel, permission, allow, deny, typ):
self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={
'allow': allow,
@ -141,10 +201,28 @@ class APIClient(LoggingClass):
def guilds_channels_list(self, guild):
r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
return Channel.create_map(self.client, r.json(), guild_id=guild)
def guilds_channels_create(self, guild, **kwargs):
r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs)
return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[]):
payload = {
'name': name,
'channel_type': channel_type,
'permission_overwrites': [i.to_dict() for i in permission_overwrites],
}
if channel_type == 'text':
pass
elif channel_type == 'voice':
if bitrate is not None:
payload['bitrate'] = bitrate
if user_limit is not None:
payload['user_limit'] = user_limit
else:
# TODO: better error here?
raise Exception('Invalid channel type: {}'.format(channel_type))
r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=payload)
return Channel.create(self.client, r.json(), guild_id=guild)
def guilds_channels_modify(self, guild, channel, position):
@ -155,21 +233,30 @@ class APIClient(LoggingClass):
def guilds_members_list(self, guild):
r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild))
return GuildMember.create_map(self.client, r.json(), guild_id=guild)
return GuildMember.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_members_get(self, guild, member):
r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member))
return GuildMember.create(self.client, r.json(), guild_id=guild)
def guilds_members_modify(self, guild, member, **kwargs):
self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs)
self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=optional(**kwargs))
def guilds_members_roles_add(self, guild, member, role):
self.http(Routes.GUILDS_MEMBERS_ROLES_ADD, dict(guild=guild, member=member, role=role))
def guilds_members_roles_remove(self, guild, member, role):
self.http(Routes.GUILDS_MEMBERS_ROLES_REMOVE, dict(guild=guild, member=member, role=role))
def guilds_members_me_nick(self, guild, nick):
self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick})
def guilds_members_kick(self, guild, member):
self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member))
def guilds_bans_list(self, guild):
r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild))
return User.create_map(self.client, r.json())
return GuildBan.create_hash(self.client, 'user.id', r.json())
def guilds_bans_create(self, guild, user, delete_message_days):
self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={
@ -202,6 +289,28 @@ class APIClient(LoggingClass):
r = self.http(Routes.GUILDS_WEBHOOKS_LIST, dict(guild=guild))
return Webhook.create_map(self.client, r.json())
def guilds_emojis_list(self, guild):
r = self.http(Routes.GUILDS_EMOJIS_LIST, dict(guild=guild))
return GuildEmoji.create_map(self.client, r.json())
def guilds_emojis_create(self, guild, **kwargs):
r = self.http(Routes.GUILDS_EMOJIS_CREATE, dict(guild=guild), json=kwargs)
return GuildEmoji.create(self.client, r.json())
def guilds_emojis_modify(self, guild, emoji, **kwargs):
r = self.http(Routes.GUILDS_EMOJIS_MODIFY, dict(guild=guild, emoji=emoji), json=kwargs)
return GuildEmoji.create(self.client, r.json())
def guilds_emojis_delete(self, guild, emoji):
self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji))
def users_me_get(self):
return User.create(self.client, self.http(Routes.USERS_ME_GET).json())
def users_me_patch(self, payload):
r = self.http(Routes.USERS_ME_PATCH, json=payload)
return User.create(self.client, r.json())
def invites_get(self, invite):
r = self.http(Routes.INVITES_GET, dict(invite=invite))
return Invite.create(self.client, r.json())
@ -236,7 +345,7 @@ class APIClient(LoggingClass):
return Webhook.create(self.client, r.json())
def webhooks_token_delete(self, webhook, token):
self.http(Routes.WEBHOOKS_TOKEN_DLEETE, dict(webhook=webhook, token=token))
self.http(Routes.WEBHOOKS_TOKEN_DELETE, dict(webhook=webhook, token=token))
def webhooks_token_execute(self, webhook, token, data, wait=False):
obj = self.http(

82
disco/api/http.py

@ -2,9 +2,12 @@ import requests
import random
import gevent
import six
import sys
from holster.enum import Enum
from disco import VERSION as disco_version
from requests import __version__ as requests_version
from disco.util.logging import LoggingClass
from disco.api.ratelimit import RateLimiter
@ -18,6 +21,12 @@ HTTPMethod = Enum(
)
def to_bytes(obj):
if isinstance(obj, six.text_type):
return obj.encode('utf-8')
return obj
class Routes(object):
"""
Simple Python object-enum of all method/url route combinations available to
@ -25,18 +34,25 @@ class Routes(object):
"""
# Gateway
GATEWAY_GET = (HTTPMethod.GET, '/gateway')
GATEWAY_BOT_GET = (HTTPMethod.GET, '/gateway/bot')
# Channels
CHANNELS = '/channels/{channel}'
CHANNELS_GET = (HTTPMethod.GET, CHANNELS)
CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS)
CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS)
CHANNELS_TYPING = (HTTPMethod.POST, CHANNELS + '/typing')
CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages')
CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages')
CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}')
CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete')
CHANNELS_MESSAGES_REACTIONS_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}/reactions/{emoji}')
CHANNELS_MESSAGES_REACTIONS_CREATE = (HTTPMethod.PUT, CHANNELS + '/messages/{message}/reactions/{emoji}/@me')
CHANNELS_MESSAGES_REACTIONS_DELETE_ME = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/@me')
CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE,
CHANNELS + '/messages/{message}/reactions/{emoji}/{user}')
CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}')
CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}')
CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites')
@ -58,6 +74,9 @@ class Routes(object):
GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members')
GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}')
GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}')
GUILDS_MEMBERS_ROLES_ADD = (HTTPMethod.PUT, GUILDS + '/members/{member}/roles/{role}')
GUILDS_MEMBERS_ROLES_REMOVE = (HTTPMethod.DELETE, GUILDS + '/members/{member}/roles/{role}')
GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick')
GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}')
GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans')
GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}')
@ -79,6 +98,10 @@ class Routes(object):
GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed')
GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed')
GUILDS_WEBHOOKS_LIST = (HTTPMethod.GET, GUILDS + '/webhooks')
GUILDS_EMOJIS_LIST = (HTTPMethod.GET, GUILDS + '/emojis')
GUILDS_EMOJIS_CREATE = (HTTPMethod.POST, GUILDS + '/emojis')
GUILDS_EMOJIS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/emojis/{emoji}')
GUILDS_EMOJIS_DELETE = (HTTPMethod.DELETE, GUILDS + '/emojis/{emoji}')
# Users
USERS = '/users'
@ -111,14 +134,39 @@ class APIException(Exception):
"""
Exception thrown when an HTTP-client level error occurs. Usually this will
be a non-success status-code, or a transient network issue.
Attributes
----------
status_code : int
The status code returned by the API for the request that triggered this
error.
"""
def __init__(self, msg, status_code=0, content=None):
self.status_code = status_code
self.content = content
self.msg = msg
def __init__(self, response, retries=None):
self.response = response
self.retries = retries
self.code = 0
self.msg = 'Request Failed ({})'.format(response.status_code)
if self.status_code:
self.msg += ' code: {}'.format(status_code)
if self.retries:
self.msg += " after {} retries".format(self.retries)
# Try to decode JSON, and extract params
try:
data = self.response.json()
if 'code' in data:
self.code = data['code']
self.msg = data['message']
elif len(data) == 1:
key, value = list(data.items())[0]
self.msg = 'Request Failed: {}: {}'.format(key, ', '.join(value))
except ValueError:
pass
# DEPRECATED: left for backwards compat
self.status_code = response.status_code
self.content = response.content
super(APIException, self).__init__(self.msg)
@ -134,9 +182,18 @@ class HTTPClient(LoggingClass):
def __init__(self, token):
super(HTTPClient, self).__init__()
py_version = '{}.{}.{}'.format(
sys.version_info.major,
sys.version_info.minor,
sys.version_info.micro)
self.limiter = RateLimiter()
self.headers = {
'Authorization': 'Bot ' + token,
'User-Agent': 'DiscordBot (https://github.com/b1naryth1ef/disco {}) Python/{} requests/{}'.format(
disco_version,
py_version,
requests_version),
}
def __call__(self, route, args=None, **kwargs):
@ -182,7 +239,8 @@ class HTTPClient(LoggingClass):
kwargs['headers'] = self.headers
# Build the bucket URL
filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in six.iteritems(args)}
args = {k: to_bytes(v) for k, v in six.iteritems(args)}
filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)}
bucket = (route[0].value, route[1].format(**filtered))
# Possibly wait if we're rate limited
@ -190,6 +248,7 @@ class HTTPClient(LoggingClass):
# Make the actual request
url = self.BASE_URL + route[1].format(**args)
self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params'))
r = requests.request(route[0].value, url, **kwargs)
# Update rate limiter
@ -198,17 +257,18 @@ class HTTPClient(LoggingClass):
# If we got a success status code, just return the data
if r.status_code < 400:
return r
elif r.status_code != 429 and 400 < r.status_code < 500:
raise APIException('Request failed', r.status_code, r.content)
elif r.status_code != 429 and 400 <= r.status_code < 500:
raise APIException(r)
else:
if r.status_code == 429:
self.log.warning('Request responded w/ 429, retrying (but this should not happen, check your clock sync')
self.log.warning(
'Request responded w/ 429, retrying (but this should not happen, check your clock sync')
# If we hit the max retries, throw an error
retry += 1
if retry > self.MAX_RETRIES:
self.log.error('Failing request, hit max retries')
raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content)
raise APIException(r, retries=self.MAX_RETRIES)
backoff = self.random_backoff()
self.log.warning('Request to `{}` failed with code {}, retrying after {}s ({})'.format(

21
disco/api/ratelimit.py

@ -1,8 +1,10 @@
import time
import gevent
from disco.util.logging import LoggingClass
class RouteState(object):
class RouteState(LoggingClass):
"""
An object which stores ratelimit state for a given method/url route
combination (as specified in :class:`disco.api.http.Routes`).
@ -36,10 +38,13 @@ class RouteState(object):
self.update(response)
def __repr__(self):
return '<RouteState {}>'.format(' '.join(self.route))
@property
def chilled(self):
"""
Whether this route is currently being cooldown (aka waiting until reset_time)
Whether this route is currently being cooldown (aka waiting until reset_time).
"""
return self.event is not None
@ -69,7 +74,7 @@ class RouteState(object):
def wait(self, timeout=None):
"""
Waits until this route is no longer under a cooldown
Waits until this route is no longer under a cooldown.
Parameters
----------
@ -80,24 +85,26 @@ class RouteState(object):
Returns
-------
bool
False if the timeout period expired before the cooldown was finished
False if the timeout period expired before the cooldown was finished.
"""
return self.event.wait(timeout)
def cooldown(self):
"""
Waits for the current route to be cooled-down (aka waiting until reset time)
Waits for the current route to be cooled-down (aka waiting until reset time).
"""
if self.reset_time - time.time() < 0:
raise Exception('Cannot cooldown for negative time period; check clock sync')
self.event = gevent.event.Event()
gevent.sleep((self.reset_time - time.time()) + .5)
delay = (self.reset_time - time.time()) + .5
self.log.debug('Cooling down bucket %s for %s seconds', self, delay)
gevent.sleep(delay)
self.event.set()
self.event = None
class RateLimiter(object):
class RateLimiter(LoggingClass):
"""
A in-memory store of ratelimit states for all routes we've ever called.

117
disco/bot/bot.py

@ -12,6 +12,7 @@ from disco.bot.plugin import Plugin
from disco.bot.command import CommandEvent, CommandLevels
from disco.bot.storage import Storage
from disco.util.config import Config
from disco.util.logging import LoggingClass
from disco.util.serializer import Serializer
@ -64,7 +65,7 @@ class BotConfig(Config):
The directory plugin configuration is located within.
"""
levels = {}
plugins = []
plugin_config = {}
commands_enabled = True
commands_require_mention = True
@ -88,7 +89,7 @@ class BotConfig(Config):
storage_config = {}
class Bot(object):
class Bot(LoggingClass):
"""
Disco's implementation of a simple but extendable Discord bot. Bots consist
of a set of plugins, and a Disco client.
@ -114,6 +115,9 @@ class Bot(object):
self.client = client
self.config = config or BotConfig()
# Shard manager
self.shards = None
# The context carries information about events in a threadlocal storage
self.ctx = ThreadLocal()
@ -122,6 +126,7 @@ class Bot(object):
if self.config.storage_enabled:
self.storage = Storage(self.ctx, self.config.from_prefix('storage'))
# If the manhole is enabled, add this bot as a local
if self.client.config.manhole_enable:
self.client.manhole_locals['bot'] = self
@ -135,6 +140,12 @@ class Bot(object):
if self.config.commands_allow_edit:
self.client.events.on('MessageUpdate', self.on_message_update)
# If we have a level getter and its a string, try to load it
if isinstance(self.config.commands_level_getter, six.string_types):
mod, func = self.config.commands_level_getter.rsplit('.', 1)
mod = importlib.import_module(mod)
self.config.commands_level_getter = getattr(mod, func)
# Stores the last message for every single channel
self.last_message_cache = {}
@ -173,10 +184,10 @@ class Bot(object):
@property
def commands(self):
"""
Generator of all commands this bots plugins have defined
Generator of all commands this bots plugins have defined.
"""
for plugin in six.itervalues(self.plugins):
for command in six.itervalues(plugin.commands):
for command in plugin.commands:
yield command
def recompute(self):
@ -190,7 +201,7 @@ class Bot(object):
def compute_group_abbrev(self):
"""
Computes all possible abbreviations for a command grouping
Computes all possible abbreviations for a command grouping.
"""
self.group_abbrev = {}
groups = set(command.group for command in self.commands if command.group)
@ -199,7 +210,7 @@ class Bot(object):
grp = group
while grp:
# If the group already exists, means someone else thought they
# could use it so we need to
# could use it so we need yank it from them (and not use it)
if grp in list(six.itervalues(self.group_abbrev)):
self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp}
else:
@ -211,13 +222,14 @@ class Bot(object):
"""
Computes a single regex which matches all possible command combinations.
"""
re_str = '|'.join(command.regex for command in self.commands)
commands = list(self.commands)
re_str = '|'.join(command.regex for command in commands)
if re_str:
self.command_matches_re = re.compile(re_str)
self.command_matches_re = re.compile(re_str, re.I)
else:
self.command_matches_re = None
def get_commands_for_message(self, msg):
def get_commands_for_message(self, require_mention, mention_rules, prefix, msg):
"""
Generator of all commands that a given message object triggers, based on
the bots plugins and configuration.
@ -234,19 +246,19 @@ class Bot(object):
"""
content = msg.content
if self.config.commands_require_mention:
if require_mention:
mention_direct = msg.is_mentioned(self.client.state.me)
mention_everyone = msg.mention_everyone
mention_roles = []
if msg.guild:
mention_roles = list(filter(lambda r: msg.is_mentioned(r),
msg.guild.get_member(self.client.state.me).roles))
msg.guild.get_member(self.client.state.me).roles))
if not any((
self.config.commands_mention_rules['user'] and mention_direct,
self.config.commands_mention_rules['everyone'] and mention_everyone,
self.config.commands_mention_rules['role'] and any(mention_roles),
mention_rules.get('user', True) and mention_direct,
mention_rules.get('everyone', False) and mention_everyone,
mention_rules.get('role', False) and any(mention_roles),
msg.channel.is_dm
)):
raise StopIteration
@ -262,14 +274,14 @@ class Bot(object):
content = content.replace('@everyone', '', 1)
else:
for role in mention_roles:
content = content.replace(role.mention, '', 1)
content = content.replace('<@{}>'.format(role), '', 1)
content = content.lstrip()
if self.config.commands_prefix and not content.startswith(self.config.commands_prefix):
if prefix and not content.startswith(prefix):
raise StopIteration
else:
content = content[len(self.config.commands_prefix):]
content = content[len(prefix):]
if not self.command_matches_re or not self.command_matches_re.match(content):
raise StopIteration
@ -283,7 +295,7 @@ class Bot(object):
level = CommandLevels.DEFAULT
if callable(self.config.commands_level_getter):
level = self.config.commands_level_getter(actor)
level = self.config.commands_level_getter(self, actor)
else:
if actor.id in self.config.levels:
level = self.config.levels[actor.id]
@ -320,19 +332,24 @@ class Bot(object):
bool
whether any commands where successfully triggered by the message
"""
commands = list(self.get_commands_for_message(msg))
commands = list(self.get_commands_for_message(
self.config.commands_require_mention,
self.config.commands_mention_rules,
self.config.commands_prefix,
msg
))
if len(commands):
result = False
for command, match in commands:
if not self.check_command_permissions(command, msg):
continue
if not len(commands):
return False
if command.plugin.execute(CommandEvent(command, msg, match)):
result = True
return result
result = False
for command, match in commands:
if not self.check_command_permissions(command, msg):
continue
return False
if command.plugin.execute(CommandEvent(command, msg, match)):
result = True
return result
def on_message_create(self, event):
if event.message.author.id == self.client.state.me.id:
@ -356,7 +373,7 @@ class Bot(object):
self.last_message_cache[msg.channel_id] = (msg, triggered)
def add_plugin(self, cls, config=None):
def add_plugin(self, cls, config=None, ctx=None):
"""
Adds and loads a plugin, based on its class.
@ -366,8 +383,12 @@ class Bot(object):
Plugin class to initialize and load.
config : Optional
The configuration to load the plugin with.
ctx : Optional[dict]
Context (previous state) to pass the plugin. Usually used along w/
unload.
"""
if cls.__name__ in self.plugins:
self.log.warning('Attempted to add already added plugin %s', cls.__name__)
raise Exception('Cannot add already added plugin: {}'.format(cls.__name__))
if not config:
@ -376,9 +397,10 @@ class Bot(object):
else:
config = self.load_plugin_config(cls)
self.plugins[cls.__name__] = cls(self, config)
self.plugins[cls.__name__].load()
self.ctx['plugin'] = self.plugins[cls.__name__] = cls(self, config)
self.plugins[cls.__name__].load(ctx or {})
self.recompute()
self.ctx.drop()
def rmv_plugin(self, cls):
"""
@ -392,9 +414,11 @@ class Bot(object):
if cls.__name__ not in self.plugins:
raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__))
self.plugins[cls.__name__].unload()
ctx = {}
self.plugins[cls.__name__].unload(ctx)
del self.plugins[cls.__name__]
self.recompute()
return ctx
def reload_plugin(self, cls):
"""
@ -402,13 +426,13 @@ class Bot(object):
"""
config = self.plugins[cls.__name__].config
self.rmv_plugin(cls)
ctx = self.rmv_plugin(cls)
module = reload_module(inspect.getmodule(cls))
self.add_plugin(getattr(module, cls.__name__), config)
self.add_plugin(getattr(module, cls.__name__), config, ctx)
def run_forever(self):
"""
Runs this bots core loop forever
Runs this bots core loop forever.
"""
self.client.run_forever()
@ -416,12 +440,14 @@ class Bot(object):
"""
Adds and loads a plugin, based on its module path.
"""
self.log.info('Adding plugin module at path "%s"', path)
mod = importlib.import_module(path)
loaded = False
for entry in map(lambda i: getattr(mod, i), dir(mod)):
if inspect.isclass(entry) and issubclass(entry, Plugin) and not entry == Plugin:
if getattr(entry, '_shallow', False) and Plugin in entry.__bases__:
continue
loaded = True
self.add_plugin(entry, config)
@ -430,23 +456,24 @@ class Bot(object):
def load_plugin_config(self, cls):
name = cls.__name__.lower()
if name.startswith('plugin'):
name = name[6:]
if name.endswith('plugin'):
name = name[:-6]
path = os.path.join(
self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format
if not os.path.exists(path):
if hasattr(cls, 'config_cls'):
return cls.config_cls()
return
data = {}
if name in self.config.plugin_config:
data = self.config.plugin_config[name]
with open(path, 'r') as f:
data = Serializer.loads(self.config.plugin_config_format, f.read())
if os.path.exists(path):
with open(path, 'r') as f:
data.update(Serializer.loads(self.config.plugin_config_format, f.read()))
if hasattr(cls, 'config_cls'):
inst = cls.config_cls()
inst.update(data)
if data:
inst.update(data)
return inst
return data

140
disco/bot/command.py

@ -5,9 +5,11 @@ from holster.enum import Enum
from disco.bot.parser import ArgumentSet, ArgumentError
from disco.util.functional import cached_property
REGEX_FMT = '({})'
ARGS_REGEX = '( (.*)$|$)'
MENTION_RE = re.compile('<@!?([0-9]+)>')
ARGS_REGEX = '(?: ((?:\n|.)*)$|$)'
USER_MENTION_RE = re.compile('<@!?([0-9]+)>')
ROLE_MENTION_RE = re.compile('<@&([0-9]+)>')
CHANNEL_MENTION_RE = re.compile('<#([0-9]+)>')
CommandLevels = Enum(
DEFAULT=0,
@ -42,34 +44,52 @@ class CommandEvent(object):
self.command = command
self.msg = msg
self.match = match
self.name = self.match.group(1)
self.args = [i for i in self.match.group(2).strip().split(' ') if i]
self.name = self.match.group(0)
self.args = []
if self.match.group(1):
self.args = [i for i in self.match.group(1).strip().split(' ') if i]
@property
def codeblock(self):
if '`' not in self.msg.content:
return ' '.join(self.args)
_, src = self.msg.content.split('`', 1)
src = '`' + src
if src.startswith('```') and src.endswith('```'):
src = src[3:-3]
elif src.startswith('`') and src.endswith('`'):
src = src[1:-1]
return src
@cached_property
def member(self):
"""
Guild member (if relevant) for the user that created the message
Guild member (if relevant) for the user that created the message.
"""
return self.guild.get_member(self.author)
@property
def channel(self):
"""
Channel the message was created in
Channel the message was created in.
"""
return self.msg.channel
@property
def guild(self):
"""
Guild (if relevant) the message was created in
Guild (if relevant) the message was created in.
"""
return self.msg.guild
@property
def author(self):
"""
Author of the message
Author of the message.
"""
return self.msg.author
@ -107,61 +127,106 @@ class Command(object):
self.plugin = plugin
self.func = func
self.triggers = [trigger]
self.dispatch_func = None
self.raw_args = None
self.args = None
self.level = None
self.group = None
self.is_regex = None
self.oob = False
self.context = {}
self.metadata = {}
self.update(*args, **kwargs)
def update(self, args=None, level=None, aliases=None, group=None, is_regex=None):
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def get_docstring(self):
return (self.func.__doc__ or '').format(**self.context)
def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None, **kwargs):
self.triggers += aliases or []
def resolve_role(ctx, id):
return ctx.msg.guild.roles.get(id)
def resolve_role(ctx, rid):
return ctx.msg.guild.roles.get(rid)
def resolve_user(ctx, uid):
if isinstance(uid, int):
if uid in ctx.msg.mentions:
return ctx.msg.mentions.get(uid)
else:
return ctx.msg.client.state.users.get(uid)
else:
return ctx.msg.client.state.users.select_one(username=uid[0], discriminator=uid[1])
def resolve_channel(ctx, cid):
if isinstance(cid, (int, long)):
return ctx.msg.guild.channels.get(cid)
else:
return ctx.msg.guild.channels.select_one(name=cid)
def resolve_user(ctx, id):
return ctx.msg.mentions.get(id)
def resolve_guild(ctx, gid):
return ctx.msg.client.state.guilds.get(gid)
self.raw_args = args
self.args = ArgumentSet.from_string(args or '', {
'mention': self.mention_type([resolve_role, resolve_user]),
'user': self.mention_type([resolve_user], force=True),
'role': self.mention_type([resolve_role], force=True),
'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True),
'role': self.mention_type([resolve_role], ROLE_MENTION_RE),
'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE, allow_plain=True),
'guild': self.mention_type([resolve_guild]),
})
self.level = level
self.group = group
self.is_regex = is_regex
self.oob = oob
self.context = context or {}
self.metadata = kwargs
@staticmethod
def mention_type(getters, force=False):
def _f(ctx, i):
res = MENTION_RE.match(i)
if not res:
raise TypeError('Invalid mention: {}'.format(i))
id = int(res.group(1))
def mention_type(getters, reg=None, user=False, allow_plain=False):
def _f(ctx, raw):
if raw.isdigit():
resolved = int(raw)
elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit():
username, discrim = raw.split('#')
resolved = (username, int(discrim))
elif reg:
res = reg.match(raw)
if res:
resolved = int(res.group(1))
else:
if allow_plain:
resolved = raw
else:
raise TypeError('Invalid mention: {}'.format(raw))
else:
raise TypeError('Invalid mention: {}'.format(raw))
for getter in getters:
obj = getter(ctx, id)
obj = getter(ctx, resolved)
if obj:
return obj
if force:
raise TypeError('Cannot resolve mention: {}'.format(id))
return id
raise TypeError('Cannot resolve mention: {}'.format(raw))
return _f
@cached_property
def compiled_regex(self):
"""
A compiled version of this command's regex
A compiled version of this command's regex.
"""
return re.compile(self.regex)
return re.compile(self.regex, re.I)
@property
def regex(self):
"""
The regex string that defines/triggers this command
The regex string that defines/triggers this command.
"""
if self.is_regex:
return REGEX_FMT.format('|'.join(self.triggers))
return '|'.join(self.triggers)
else:
group = ''
if self.group:
@ -169,7 +234,7 @@ class Command(object):
group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group))
else:
group = self.group + ' '
return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX)
return '^{}(?:{})'.format(group, '|'.join(self.triggers)) + ARGS_REGEX
def execute(self, event):
"""
@ -189,8 +254,11 @@ class Command(object):
))
try:
args = self.args.parse(event.args, ctx=event)
parsed_args = self.args.parse(event.args, ctx=event)
except ArgumentError as e:
raise CommandError(e.message)
return self.func(event, *args)
kwargs = {}
kwargs.update(self.context)
kwargs.update(parsed_args)
return self.plugin.dispatch('command', self, event, **kwargs)

76
disco/bot/parser.py

@ -2,9 +2,10 @@ import re
import six
import copy
# Regex which splits out argument parts
PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])')
PARTS_RE = re.compile('(\<|\[|\{)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\]|\})')
BOOL_OPTS = {'yes': True, 'no': False, 'true': True, 'False': False, '1': True, '0': False}
# Mapping of types
TYPE_MAP = {
@ -14,6 +15,20 @@ TYPE_MAP = {
'snowflake': lambda ctx, data: int(data),
}
try:
import dateparser
TYPE_MAP['duration'] = lambda ctx, data: dateparser.parse(data, settings={'TIMEZONE': 'UTC'})
except ImportError:
pass
def to_bool(ctx, data):
if data in BOOL_OPTS:
return BOOL_OPTS[data]
raise TypeError
TYPE_MAP['bool'] = to_bool
class ArgumentError(Exception):
"""
@ -41,19 +56,20 @@ class Argument(object):
self.name = None
self.count = 1
self.required = False
self.flag = False
self.types = None
self.parse(raw)
@property
def true_count(self):
"""
The true number of raw arguments this argument takes
The true number of raw arguments this argument takes.
"""
return self.count or 1
def parse(self, raw):
"""
Attempts to parse arguments from their raw form
Attempts to parse arguments from their raw form.
"""
prefix, part = raw
@ -62,23 +78,27 @@ class Argument(object):
else:
self.required = False
if part.endswith('...'):
part = part[:-3]
self.count = 0
elif ' ' in part:
split = part.split(' ', 1)
part, self.count = split[0], int(split[1])
# Whether this is a flag
self.flag = (prefix == '{')
if not self.flag:
if part.endswith('...'):
part = part[:-3]
self.count = 0
elif ' ' in part:
split = part.split(' ', 1)
part, self.count = split[0], int(split[1])
if ':' in part:
part, typeinfo = part.split(':')
self.types = typeinfo.split('|')
if ':' in part:
part, typeinfo = part.split(':')
self.types = typeinfo.split('|')
self.name = part.strip()
class ArgumentSet(object):
"""
A set of :class:`Argument` instances which forms a larger argument specification
A set of :class:`Argument` instances which forms a larger argument specification.
Attributes
----------
@ -95,7 +115,7 @@ class ArgumentSet(object):
@classmethod
def from_string(cls, line, custom_types=None):
"""
Creates a new :class:`ArgumentSet` from a given argument string specification
Creates a new :class:`ArgumentSet` from a given argument string specification.
"""
args = cls(custom_types=custom_types)
@ -131,7 +151,7 @@ class ArgumentSet(object):
def append(self, arg):
"""
Add a new :class:`Argument` to this argument specification/set
Add a new :class:`Argument` to this argument specification/set.
"""
if self.args and not self.args[-1].required and arg.required:
raise Exception('Required argument cannot come after an optional argument')
@ -145,9 +165,23 @@ class ArgumentSet(object):
"""
Parse a string of raw arguments into this argument specification.
"""
parsed = []
parsed = {}
flags = {i.name: i for i in self.args if i.flag}
if flags:
new_rawargs = []
for offset, raw in enumerate(rawargs):
if raw.startswith('-'):
raw = raw.lstrip('-')
if raw in flags:
parsed[raw] = True
continue
new_rawargs.append(raw)
rawargs = new_rawargs
for index, arg in enumerate(self.args):
for index, arg in enumerate((arg for arg in self.args if not arg.flag)):
if not arg.required and index + arg.true_count > len(rawargs):
continue
@ -171,20 +205,20 @@ class ArgumentSet(object):
if (not arg.types or arg.types == ['str']) and isinstance(raw, list):
raw = ' '.join(raw)
parsed.append(raw)
parsed[arg.name] = raw
return parsed
@property
def length(self):
"""
The number of arguments in this set/specification
The number of arguments in this set/specification.
"""
return len(self.args)
@property
def required_length(self):
"""
The number of required arguments to compile this set/specificaiton
The number of required arguments to compile this set/specificaiton.
"""
return sum([i.true_count for i in self.args if i.required])

187
disco/bot/plugin.py

@ -1,9 +1,11 @@
import six
import types
import gevent
import inspect
import weakref
import functools
from gevent.event import AsyncResult
from holster.emitter import Priority
from disco.util.logging import LoggingClass
@ -18,8 +20,8 @@ class PluginDeco(object):
Prio = Priority
# TODO: dont smash class methods
@staticmethod
def add_meta_deco(meta):
@classmethod
def add_meta_deco(cls, meta):
def deco(f):
if not hasattr(f, 'meta'):
f.meta = []
@ -40,33 +42,33 @@ class PluginDeco(object):
return deco
@classmethod
def listen(cls, event_name, priority=None):
def listen(cls, *args, **kwargs):
"""
Binds the function to listen for a given event name
Binds the function to listen for a given event name.
"""
return cls.add_meta_deco({
'type': 'listener',
'what': 'event',
'desc': event_name,
'priority': priority
'args': args,
'kwargs': kwargs,
})
@classmethod
def listen_packet(cls, op, priority=None):
def listen_packet(cls, *args, **kwargs):
"""
Binds the function to listen for a given gateway op code
Binds the function to listen for a given gateway op code.
"""
return cls.add_meta_deco({
'type': 'listener',
'what': 'packet',
'desc': op,
'priority': priority,
'args': args,
'kwargs': kwargs,
})
@classmethod
def command(cls, *args, **kwargs):
"""
Creates a new command attached to the function
Creates a new command attached to the function.
"""
return cls.add_meta_deco({
'type': 'command',
@ -77,7 +79,7 @@ class PluginDeco(object):
@classmethod
def pre_command(cls):
"""
Runs a function before a command is triggered
Runs a function before a command is triggered.
"""
return cls.add_meta_deco({
'type': 'pre_command',
@ -86,7 +88,7 @@ class PluginDeco(object):
@classmethod
def post_command(cls):
"""
Runs a function after a command is triggered
Runs a function after a command is triggered.
"""
return cls.add_meta_deco({
'type': 'post_command',
@ -95,7 +97,7 @@ class PluginDeco(object):
@classmethod
def pre_listener(cls):
"""
Runs a function before a listener is triggered
Runs a function before a listener is triggered.
"""
return cls.add_meta_deco({
'type': 'pre_listener',
@ -104,7 +106,7 @@ class PluginDeco(object):
@classmethod
def post_listener(cls):
"""
Runs a function after a listener is triggered
Runs a function after a listener is triggered.
"""
return cls.add_meta_deco({
'type': 'post_listener',
@ -113,7 +115,7 @@ class PluginDeco(object):
@classmethod
def schedule(cls, *args, **kwargs):
"""
Runs a function repeatedly, waiting for a specified interval
Runs a function repeatedly, waiting for a specified interval.
"""
return cls.add_meta_deco({
'type': 'schedule',
@ -153,46 +155,101 @@ class Plugin(LoggingClass, PluginDeco):
self.storage = bot.storage
self.config = config
# General declartions
self.listeners = []
self.commands = []
self.schedules = {}
self.greenlets = weakref.WeakSet()
self._pre = {}
self._post = {}
# This is an array of all meta functions we sniff at init
self.meta_funcs = []
for name, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, 'meta'):
self.meta_funcs.append(member)
# Unsmash local functions
if hasattr(Plugin, name):
method = types.MethodType(getattr(Plugin, name), self, self.__class__)
setattr(self, name, method)
self.bind_all()
@property
def name(self):
return self.__class__.__name__
def bind_all(self):
self.listeners = []
self.commands = {}
self.commands = []
self.schedules = {}
self.greenlets = weakref.WeakSet()
self._pre = {'command': [], 'listener': []}
self._post = {'command': [], 'listener': []}
# TODO: when handling events/commands we need to track the greenlet in
# the greenlets set so we can termiante long running commands/listeners
# on reload.
for member in self.meta_funcs:
for meta in member.meta:
self.bind_meta(member, meta)
def bind_meta(self, member, meta):
if meta['type'] == 'listener':
self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs'])
elif meta['type'] == 'command':
# meta['kwargs']['update'] = True
self.register_command(member, *meta['args'], **meta['kwargs'])
elif meta['type'] == 'schedule':
self.register_schedule(member, *meta['args'], **meta['kwargs'])
elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'):
when, typ = meta['type'].split('_', 1)
self.register_trigger(typ, when, member)
def handle_exception(self, greenlet, event):
pass
def wait_for_event(self, event_name, **kwargs):
result = AsyncResult()
listener = None
def _event_callback(event):
for k, v in kwargs.items():
if getattr(event, k) != v:
break
else:
listener.remove()
return result.set(event)
for name, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, 'meta'):
for meta in member.meta:
if meta['type'] == 'listener':
self.register_listener(member, meta['what'], meta['desc'], meta['priority'])
elif meta['type'] == 'command':
meta['kwargs']['update'] = True
self.register_command(member, *meta['args'], **meta['kwargs'])
elif meta['type'] == 'schedule':
self.register_schedule(member, *meta['args'], **meta['kwargs'])
elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'):
when, typ = meta['type'].split('_', 1)
self.register_trigger(typ, when, member)
def spawn(self, method, *args, **kwargs):
obj = gevent.spawn(method, *args, **kwargs)
listener = self.bot.client.events.on(event_name, _event_callback)
return result
def spawn_wrap(self, spawner, method, *args, **kwargs):
def wrapped(*args, **kwargs):
self.ctx['plugin'] = self
try:
res = method(*args, **kwargs)
return res
finally:
self.ctx.drop()
obj = spawner(wrapped, *args, **kwargs)
self.greenlets.add(obj)
return obj
def spawn(self, *args, **kwargs):
return self.spawn_wrap(gevent.spawn, *args, **kwargs)
def spawn_later(self, delay, *args, **kwargs):
return self.spawn_wrap(functools.partial(gevent.spawn_later, delay), *args, **kwargs)
def execute(self, event):
"""
Executes a CommandEvent this plugin owns
Executes a CommandEvent this plugin owns.
"""
if not event.command.oob:
self.greenlets.add(gevent.getcurrent())
try:
return event.command.execute(event)
except CommandError as e:
@ -203,11 +260,18 @@ class Plugin(LoggingClass, PluginDeco):
def register_trigger(self, typ, when, func):
"""
Registers a trigger
Registers a trigger.
"""
getattr(self, '_' + when)[typ].append(func)
def _dispatch(self, typ, func, event, *args, **kwargs):
def dispatch(self, typ, func, event, *args, **kwargs):
# Link the greenlet with our exception handler
gevent.getcurrent().link_exception(lambda g: self.handle_exception(g, event))
# TODO: this is ugly
if typ != 'command':
self.greenlets.add(gevent.getcurrent())
self.ctx['plugin'] = self
if hasattr(event, 'guild'):
@ -218,7 +282,7 @@ class Plugin(LoggingClass, PluginDeco):
self.ctx['user'] = event.author
for pre in self._pre[typ]:
event = pre(event, args, kwargs)
event = pre(func, event, args, kwargs)
if event is None:
return False
@ -226,13 +290,13 @@ class Plugin(LoggingClass, PluginDeco):
result = func(event, *args, **kwargs)
for post in self._post[typ]:
post(event, args, kwargs, result)
post(func, event, args, kwargs, result)
return True
def register_listener(self, func, what, desc, priority):
def register_listener(self, func, what, *args, **kwargs):
"""
Registers a listener
Registers a listener.
Parameters
----------
@ -242,17 +306,13 @@ class Plugin(LoggingClass, PluginDeco):
The function to be registered.
desc
The descriptor of the event/packet.
priority : Priority
The priority of this listener.
"""
func = functools.partial(self._dispatch, 'listener', func)
priority = priority or Priority.NONE
args = list(args) + [functools.partial(self.dispatch, 'listener', func)]
if what == 'event':
li = self.bot.client.events.on(desc, func, priority=priority)
li = self.bot.client.events.on(*args, **kwargs)
elif what == 'packet':
li = self.bot.client.packets.on(desc, func, priority=priority)
li = self.bot.client.packets.on(*args, **kwargs)
else:
raise Exception('Invalid listener what: {}'.format(what))
@ -260,7 +320,7 @@ class Plugin(LoggingClass, PluginDeco):
def register_command(self, func, *args, **kwargs):
"""
Registers a command
Registers a command.
Parameters
----------
@ -272,11 +332,7 @@ class Plugin(LoggingClass, PluginDeco):
Keyword arguments to pass onto the :class:`disco.bot.command.Command`
object.
"""
if kwargs.pop('update', False) and func.__name__ in self.commands:
self.commands[func.__name__].update(*args, **kwargs)
else:
wrapped = functools.partial(self._dispatch, 'command', func)
self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs)
self.commands.append(Command(self, func, *args, **kwargs))
def register_schedule(self, func, interval, repeat=True, init=True):
"""
@ -289,8 +345,13 @@ class Plugin(LoggingClass, PluginDeco):
The function to be registered.
interval : int
Interval (in seconds) to repeat the function on.
repeat : bool
Whether this schedule is repeating (or one time).
init : bool
Whether to run this schedule once immediatly, or wait for the first
scheduled iteration.
"""
def repeat():
def repeat_func():
if init:
func()
@ -300,17 +361,17 @@ class Plugin(LoggingClass, PluginDeco):
if not repeat:
break
self.schedules[func.__name__] = self.spawn(repeat)
self.schedules[func.__name__] = self.spawn(repeat_func)
def load(self):
def load(self, ctx):
"""
Called when the plugin is loaded
Called when the plugin is loaded.
"""
self.bind_all()
pass
def unload(self):
def unload(self, ctx):
"""
Called when the plugin is unloaded
Called when the plugin is unloaded.
"""
for greenlet in self.greenlets:
greenlet.kill()

1
disco/bot/providers/disk.py

@ -13,6 +13,7 @@ class DiskProvider(BaseProvider):
self.fsync = config.get('fsync', False)
self.fsync_changes = config.get('fsync_changes', 1)
self.autosave_task = None
self.change_count = 0
def autosave_loop(self, interval):

25
disco/bot/providers/redis.py

@ -10,32 +10,39 @@ from .base import BaseProvider, SEP_SENTINEL
class RedisProvider(BaseProvider):
def __init__(self, config):
self.config = config
super(RedisProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.conn = None
def load(self):
self.redis = redis.Redis(
self.conn = redis.Redis(
host=self.config.get('host', 'localhost'),
port=self.config.get('port', 6379),
db=self.config.get('db', 0))
def exists(self, key):
return self.db.exists(key)
return self.conn.exists(key)
def keys(self, other):
count = other.count(SEP_SENTINEL) + 1
for key in self.db.scan_iter(u'{}*'.format(other)):
for key in self.conn.scan_iter(u'{}*'.format(other)):
key = key.decode('utf-8')
if key.count(SEP_SENTINEL) == count:
yield key
def get_many(self, keys):
for key, value in izip(keys, self.db.mget(keys)):
keys = list(keys)
if not len(keys):
raise StopIteration
for key, value in izip(keys, self.conn.mget(keys)):
yield (key, Serializer.loads(self.format, value))
def get(self, key):
return Serializer.loads(self.format, self.db.get(key))
return Serializer.loads(self.format, self.conn.get(key))
def set(self, key, value):
self.db.set(key, Serializer.dumps(self.format, value))
self.conn.set(key, Serializer.dumps(self.format, value))
def delete(self, key, value):
self.db.delete(key)
def delete(self, key):
self.conn.delete(key)

6
disco/bot/providers/rocksdb.py

@ -12,11 +12,13 @@ from .base import BaseProvider, SEP_SENTINEL
class RocksDBProvider(BaseProvider):
def __init__(self, config):
self.config = config
super(RocksDBProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.path = config.get('path', 'storage.db')
self.db = None
def k(self, k):
@staticmethod
def k(k):
return bytes(k) if six.PY3 else str(k.encode('utf-8'))
def load(self):

17
disco/cli.py

@ -18,14 +18,13 @@ parser.add_argument('--config', help='Configuration file', default='config.yaml'
parser.add_argument('--token', help='Bot Authentication Token', default=None)
parser.add_argument('--shard-count', help='Total number of shards', default=None)
parser.add_argument('--shard-id', help='Current shard number/id', default=None)
parser.add_argument('--shard-auto', help='Automatically run all shards', action='store_true', default=False)
parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=None)
parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default=None)
parser.add_argument('--encoder', help='encoder for gateway data', default=None)
parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False)
parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[])
logging.basicConfig(level=logging.INFO)
def disco_main(run=False):
"""
@ -42,6 +41,7 @@ def disco_main(run=False):
from disco.client import Client, ClientConfig
from disco.bot import Bot, BotConfig
from disco.util.token import is_valid_token
from disco.util.logging import setup_logging
if os.path.exists(args.config):
config = ClientConfig.from_file(args.config)
@ -56,12 +56,23 @@ def disco_main(run=False):
print('Invalid token passed')
return
if args.shard_auto:
from disco.gateway.sharder import AutoSharder
AutoSharder(config).run()
return
# TODO: make configurable
setup_logging(level=logging.INFO)
client = Client(config)
bot = None
if args.run_bot or hasattr(config, 'bot'):
bot_config = BotConfig(config.bot) if hasattr(config, 'bot') else BotConfig()
bot_config.plugins += args.plugin
if not hasattr(bot_config, 'plugins'):
bot_config.plugins = args.plugin
else:
bot_config.plugins += args.plugin
bot = Bot(client, bot_config)
if run:

45
disco/client.py

@ -1,3 +1,4 @@
import time
import gevent
from holster.emitter import Emitter
@ -5,24 +6,28 @@ from holster.emitter import Emitter
from disco.state import State, StateConfig
from disco.api.client import APIClient
from disco.gateway.client import GatewayClient
from disco.gateway.packets import OPCode
from disco.types.user import Status, Game
from disco.util.config import Config
from disco.util.logging import LoggingClass
from disco.util.backdoor import DiscoBackdoorServer
class ClientConfig(LoggingClass, Config):
class ClientConfig(Config):
"""
Configuration for the :class:`Client`.
Attributes
----------
token : str
Discord authentication token, ca be validated using the
Discord authentication token, can be validated using the
:func:`disco.util.token.is_valid_token` function.
shard_id : int
The shard ID for the current client instance.
shard_count : int
The total count of shards running.
max_reconnects : int
The maximum number of connection retries to make before giving up (0 = never give up).
manhole_enable : bool
Whether to enable the manhole (e.g. console backdoor server) utility.
manhole_bind : tuple(str, int)
@ -36,14 +41,15 @@ class ClientConfig(LoggingClass, Config):
token = ""
shard_id = 0
shard_count = 1
max_reconnects = 5
manhole_enable = True
manhole_enable = False
manhole_bind = ('127.0.0.1', 8484)
encoder = 'json'
class Client(object):
class Client(LoggingClass):
"""
Class representing the base entry point that should be used in almost all
implementation cases. This class wraps the functionality of both the REST API
@ -82,8 +88,8 @@ class Client(object):
self.events = Emitter(gevent.spawn)
self.packets = Emitter(gevent.spawn)
self.api = APIClient(self)
self.gw = GatewayClient(self, self.config.encoder)
self.api = APIClient(self.config.token, self)
self.gw = GatewayClient(self, self.config.max_reconnects, self.config.encoder)
self.state = State(self, StateConfig(self.config.get('state', {})))
if self.config.manhole_enable:
@ -95,18 +101,37 @@ class Client(object):
}
self.manhole = DiscoBackdoorServer(self.config.manhole_bind,
banner='Disco Manhole',
localf=lambda: self.manhole_locals)
banner='Disco Manhole',
localf=lambda: self.manhole_locals)
self.manhole.start()
def update_presence(self, game=None, status=None, afk=False, since=0.0):
if game and not isinstance(game, Game):
raise TypeError('Game must be a Game model')
if status is Status.IDLE and not since:
since = int(time.time() * 1000)
payload = {
'afk': afk,
'since': since,
'status': status.value.lower(),
'game': None,
}
if game:
payload['game'] = game.to_dict()
self.gw.send(OPCode.STATUS_UPDATE, payload)
def run(self):
"""
Run the client (e.g. the :class:`GatewayClient`) in a new greenlet
Run the client (e.g. the :class:`GatewayClient`) in a new greenlet.
"""
return gevent.spawn(self.gw.run)
def run_forever(self):
"""
Run the client (e.g. the :class:`GatewayClient`) in the current greenlet
Run the client (e.g. the :class:`GatewayClient`) in the current greenlet.
"""
return self.gw.run()

52
disco/gateway/client.py

@ -15,16 +15,21 @@ TEN_MEGABYTES = 10490000
class GatewayClient(LoggingClass):
GATEWAY_VERSION = 6
MAX_RECONNECTS = 5
def __init__(self, client, encoder='json'):
def __init__(self, client, max_reconnects=5, encoder='json', ipc=None):
super(GatewayClient, self).__init__()
self.client = client
self.max_reconnects = max_reconnects
self.encoder = ENCODERS[encoder]
self.events = client.events
self.packets = client.packets
# IPC for shards
if ipc:
self.shards = ipc.get_shards()
self.ipc = ipc
# Its actually 60, 120 but lets give ourselves a buffer
self.limiter = SimpleLimiter(60, 130)
@ -37,6 +42,7 @@ class GatewayClient(LoggingClass):
# Bind to ready payload
self.events.on('Ready', self.on_ready)
self.events.on('Resumed', self.on_resumed)
# Websocket connection
self.ws = None
@ -76,15 +82,15 @@ class GatewayClient(LoggingClass):
self.log.debug('Dispatching %s', obj.__class__.__name__)
self.client.events.emit(obj.__class__.__name__, obj)
def handle_heartbeat(self, packet):
def handle_heartbeat(self, _):
self._send(OPCode.HEARTBEAT, self.seq)
def handle_reconnect(self, packet):
def handle_reconnect(self, _):
self.log.warning('Received RECONNECT request, forcing a fresh reconnect')
self.session_id = None
self.ws.close()
def handle_invalid_session(self, packet):
def handle_invalid_session(self, _):
self.log.warning('Recieved INVALID_SESSION, forcing a fresh reconnect')
self.session_id = None
self.ws.close()
@ -98,14 +104,21 @@ class GatewayClient(LoggingClass):
self.session_id = ready.session_id
self.reconnects = 0
def connect_and_run(self):
if not self._cached_gateway_url:
self._cached_gateway_url = self.client.api.gateway(
version=self.GATEWAY_VERSION,
encoding=self.encoder.TYPE)
def on_resumed(self, _):
self.log.info('Recieved RESUMED')
self.reconnects = 0
def connect_and_run(self, gateway_url=None):
if not gateway_url:
if not self._cached_gateway_url:
self._cached_gateway_url = self.client.api.gateway_get()['url']
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url)
self.ws = Websocket(self._cached_gateway_url)
gateway_url = self._cached_gateway_url
gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE)
self.log.info('Opening websocket connection to URL `%s`', gateway_url)
self.ws = Websocket(gateway_url)
self.ws.emitter.on('on_open', self.on_open)
self.ws.emitter.on('on_error', self.on_error)
self.ws.emitter.on('on_close', self.on_close)
@ -153,8 +166,8 @@ class GatewayClient(LoggingClass):
'compress': True,
'large_threshold': 250,
'shard': [
self.client.config.shard_id,
self.client.config.shard_count,
int(self.client.config.shard_id),
int(self.client.config.shard_count),
],
'properties': {
'$os': 'linux',
@ -165,15 +178,22 @@ class GatewayClient(LoggingClass):
})
def on_close(self, code, reason):
# Kill heartbeater, a reconnect/resume will trigger a HELLO which will
# respawn it
if self._heartbeat_task:
self._heartbeat_task.kill()
# If we're quitting, just break out of here
if self.shutting_down:
self.log.info('WS Closed: shutting down')
return
# Track reconnect attempts
self.reconnects += 1
self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects)
if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS:
raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS))
if self.max_reconnects and self.reconnects > self.max_reconnects:
raise Exception('Failed to reconnect after {} attempts, giving up'.format(self.max_reconnects))
# Don't resume for these error codes
if code and 4000 <= code <= 4010:

4
disco/gateway/encoding/base.py

@ -1,7 +1,9 @@
from websocket import ABNF
from holster.interface import Interface
class BaseEncoder(object):
class BaseEncoder(Interface):
TYPE = None
OPCODE = ABNF.OPCODE_TEXT

2
disco/gateway/encoding/json.py

@ -1,7 +1,5 @@
from __future__ import absolute_import, print_function
import six
try:
import ujson as json
except ImportError:

424
disco/gateway/events.py

@ -4,20 +4,20 @@ import inflection
import six
from disco.types.user import User, Presence
from disco.types.channel import Channel
from disco.types.message import Message
from disco.types.channel import Channel, PermissionOverwrite
from disco.types.message import Message, MessageReactionEmoji
from disco.types.voice import VoiceState
from disco.types.guild import Guild, GuildMember, Role
from disco.types.guild import Guild, GuildMember, Role, GuildEmoji
from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime
from disco.types.base import Model, ModelMeta, Field, ListField, AutoDictField, snowflake, datetime
# Mapping of discords event name to our event classes
EVENTS_MAP = {}
class GatewayEventMeta(ModelMeta):
def __new__(cls, name, parents, dct):
obj = super(GatewayEventMeta, cls).__new__(cls, name, parents, dct)
def __new__(mcs, name, parents, dct):
obj = super(GatewayEventMeta, mcs).__new__(mcs, name, parents, dct)
if name != 'GatewayEvent':
EVENTS_MAP[inflection.underscore(name).upper()] = obj
@ -64,22 +64,21 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)):
return cls(obj, client)
def __getattr__(self, name):
if hasattr(self, '_wraps_model'):
modname, _ = self._wraps_model
if hasattr(self, modname) and hasattr(getattr(self, modname), name):
return getattr(getattr(self, modname), name)
raise AttributeError(name)
if hasattr(self, '_proxy'):
return getattr(getattr(self, self._proxy), name)
return object.__getattribute__(self, name)
def debug(func=None):
def debug(func=None, match=None):
def deco(cls):
old_init = cls.__init__
def new_init(self, obj, *args, **kwargs):
if func:
print(func(obj))
else:
print(obj)
if not match or match(obj):
if func:
print(func(obj))
else:
print(obj)
old_init(self, obj, *args, **kwargs)
@ -93,8 +92,16 @@ def wraps_model(model, alias=None):
def deco(cls):
cls._fields[alias] = Field(model)
cls._fields[alias].set_name(alias)
cls._fields[alias].name = alias
cls._wraps_model = (alias, model)
cls._proxy = alias
return cls
return deco
def proxy(field):
def deco(cls):
cls._proxy = field
return cls
return deco
@ -103,49 +110,102 @@ class Ready(GatewayEvent):
"""
Sent after the initial gateway handshake is complete. Contains data required
for bootstrapping the client's states.
Attributes
-----
version : int
The gateway version.
session_id : str
The session ID.
user : :class:`disco.types.user.User`
The user object for the authed account.
guilds : list[:class:`disco.types.guild.Guild`
All guilds this account is a member of. These are shallow guild objects.
private_channels list[:class:`disco.types.channel.Channel`]
All private channels (DMs) open for this account.
"""
version = Field(int, alias='v')
session_id = Field(str)
user = Field(User)
guilds = Field(listof(Guild))
private_channels = Field(listof(Channel))
guilds = ListField(Guild)
private_channels = ListField(Channel)
trace = ListField(str, alias='_trace')
class Resumed(GatewayEvent):
"""
Sent after a resume completes.
"""
pass
trace = ListField(str, alias='_trace')
@wraps_model(Guild)
class GuildCreate(GatewayEvent):
"""
Sent when a guild is created, or becomes available.
Sent when a guild is joined, or becomes available.
Attributes
-----
guild : :class:`disco.types.guild.Guild`
The guild being created (e.g. joined)
unavailable : bool
If false, this guild is coming online from a previously unavailable state,
and if None, this is a normal guild join event.
"""
unavailable = Field(bool)
@property
def created(self):
"""
Shortcut property which is true when we actually joined the guild.
"""
return self.unavailable is None
@wraps_model(Guild)
class GuildUpdate(GatewayEvent):
"""
Sent when a guild is updated.
Attributes
-----
guild : :class:`disco.types.guild.Guild`
The updated guild object.
"""
pass
class GuildDelete(GatewayEvent):
"""
Sent when a guild is deleted, or becomes unavailable.
Sent when a guild is deleted, left, or becomes unavailable.
Attributes
-----
id : snowflake
The ID of the guild being deleted.
unavailable : bool
If true, this guild is becoming unavailable, if None this is a normal
guild leave event.
"""
id = Field(snowflake)
unavailable = Field(bool)
@property
def deleted(self):
"""
Shortcut property which is true when we actually have left the guild.
"""
return self.unavailable is None
@wraps_model(Channel)
class ChannelCreate(GatewayEvent):
"""
Sent when a channel is created.
Attributes
-----
channel : :class:`disco.types.channel.Channel`
The channel which was created.
"""
@ -153,115 +213,236 @@ class ChannelCreate(GatewayEvent):
class ChannelUpdate(ChannelCreate):
"""
Sent when a channel is updated.
Attributes
-----
channel : :class:`disco.types.channel.Channel`
The channel which was updated.
"""
pass
overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites')
@wraps_model(Channel)
class ChannelDelete(ChannelCreate):
"""
Sent when a channel is deleted.
Attributes
-----
channel : :class:`disco.types.channel.Channel`
The channel being deleted.
"""
pass
class ChannelPinsUpdate(GatewayEvent):
"""
Sent when a channel's pins are updated.
Attributes
-----
channel_id : snowflake
ID of the channel where pins where updated.
last_pin_timestap : datetime
The time the last message was pinned.
"""
channel_id = Field(snowflake)
last_pin_timestamp = Field(lazy_datetime)
last_pin_timestamp = Field(datetime)
@wraps_model(User)
@proxy(User)
class GuildBanAdd(GatewayEvent):
"""
Sent when a user is banned from a guild.
Attributes
-----
guild_id : snowflake
The ID of the guild the user is being banned from.
user : :class:`disco.types.user.User`
The user being banned from the guild.
"""
pass
guild_id = Field(snowflake)
user = Field(User)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(User)
@proxy(User)
class GuildBanRemove(GuildBanAdd):
"""
Sent when a user is unbanned from a guild.
Attributes
-----
guild_id : snowflake
The ID of the guild the user is being unbanned from.
user : :class:`disco.types.user.User`
The user being unbanned from the guild.
"""
pass
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class GuildEmojisUpdate(GatewayEvent):
"""
Sent when a guild's emojis are updated.
Attributes
-----
guild_id : snowflake
The ID of the guild the emojis are being updated in.
emojis : list[:class:`disco.types.guild.Emoji`]
The new set of emojis for the guild
"""
pass
guild_id = Field(snowflake)
emojis = ListField(GuildEmoji)
class GuildIntegrationsUpdate(GatewayEvent):
"""
Sent when a guild's integrations are updated.
Attributes
-----
guild_id : snowflake
The ID of the guild integrations where updated in.
"""
pass
guild_id = Field(snowflake)
class GuildMembersChunk(GatewayEvent):
"""
Sent in response to a member's chunk request.
Attributes
-----
guild_id : snowflake
The ID of the guild this member chunk is for.
members : list[:class:`disco.types.guild.GuildMember`]
The chunk of members.
"""
guild_id = Field(snowflake)
members = Field(listof(GuildMember))
members = ListField(GuildMember)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(GuildMember, alias='member')
class GuildMemberAdd(GatewayEvent):
"""
Sent when a user joins a guild.
Attributes
-----
member : :class:`disco.types.guild.GuildMember`
The member that has joined the guild.
"""
pass
@proxy('user')
class GuildMemberRemove(GatewayEvent):
"""
Sent when a user leaves a guild (via leaving, kicking, or banning).
Attributes
-----
guild_id : snowflake
The ID of the guild the member left from.
user : :class:`disco.types.user.User`
The user who was removed from the guild.
"""
guild_id = Field(snowflake)
user = Field(User)
guild_id = Field(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(GuildMember, alias='member')
class GuildMemberUpdate(GatewayEvent):
"""
Sent when a guilds member is updated.
Attributes
-----
member : :class:`disco.types.guild.GuildMember`
The member being updated
"""
pass
@proxy('role')
class GuildRoleCreate(GatewayEvent):
"""
Sent when a role is created.
Attributes
-----
guild_id : snowflake
The ID of the guild where the role was created.
role : :class:`disco.types.guild.Role`
The role that was created.
"""
guild_id = Field(snowflake)
role = Field(Role)
guild_id = Field(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@proxy('role')
class GuildRoleUpdate(GuildRoleCreate):
"""
Sent when a role is updated.
Attributes
-----
guild_id : snowflake
The ID of the guild where the role was created.
role : :class:`disco.types.guild.Role`
The role that was created.
"""
pass
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class GuildRoleDelete(GatewayEvent):
"""
Sent when a role is deleted.
Attributes
-----
guild_id : snowflake
The ID of the guild where the role is being deleted.
role_id : snowflake
The id of the role being deleted.
"""
guild_id = Field(snowflake)
role_id = Field(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@wraps_model(Message)
class MessageCreate(GatewayEvent):
"""
Sent when a message is created.
Attributes
-----
message : :class:`disco.types.message.Message`
The message being created.
"""
@ -269,55 +450,124 @@ class MessageCreate(GatewayEvent):
class MessageUpdate(MessageCreate):
"""
Sent when a message is updated/edited.
Attributes
-----
message : :class:`disco.types.message.Message`
The message being updated.
"""
pass
class MessageDelete(GatewayEvent):
"""
Sent when a message is deleted.
Attributes
-----
id : snowflake
The ID of message being deleted.
channel_id : snowflake
The ID of the channel the message was deleted in.
"""
id = Field(snowflake)
channel_id = Field(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageDeleteBulk(GatewayEvent):
"""
Sent when multiple messages are deleted from a channel.
Attributes
-----
channel_id : snowflake
The channel the messages are being deleted in.
ids : list[snowflake]
List of messages being deleted in the channel.
"""
channel_id = Field(snowflake)
ids = Field(listof(snowflake))
ids = ListField(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
@wraps_model(Presence)
class PresenceUpdate(GatewayEvent):
"""
Sent when a user's presence is updated.
Attributes
-----
presence : :class:`disco.types.user.Presence`
The updated presence object.
guild_id : snowflake
The guild this presence update is for.
roles : list[snowflake]
List of roles the user from the presence is part of.
"""
guild_id = Field(snowflake)
roles = Field(listof(snowflake))
roles = ListField(snowflake)
@property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class TypingStart(GatewayEvent):
"""
Sent when a user begins typing in a channel.
Attributes
-----
channel_id : snowflake
The ID of the channel where the user is typing.
user_id : snowflake
The ID of the user who is typing.
timestamp : datetime
When the user started typing.
"""
channel_id = Field(snowflake)
user_id = Field(snowflake)
timestamp = Field(snowflake)
timestamp = Field(datetime)
@wraps_model(VoiceState, alias='state')
class VoiceStateUpdate(GatewayEvent):
"""
Sent when a users voice state changes.
Attributes
-----
state : :class:`disco.models.voice.VoiceState`
The voice state which was updated.
"""
pass
class VoiceServerUpdate(GatewayEvent):
"""
Sent when a voice server is updated.
Attributes
-----
token : str
The token for the voice server.
endpoint : str
The endpoint for the voice server.
guild_id : snowflake
The guild ID this voice server update is for.
"""
token = Field(str)
endpoint = Field(str)
@ -327,6 +577,94 @@ class VoiceServerUpdate(GatewayEvent):
class WebhooksUpdate(GatewayEvent):
"""
Sent when a channels webhooks are updated.
Attributes
-----
channel_id : snowflake
The channel ID this webhooks update is for.
guild_id : snowflake
The guild ID this webhooks update is for.
"""
channel_id = Field(snowflake)
guild_id = Field(snowflake)
class MessageReactionAdd(GatewayEvent):
"""
Sent when a reaction is added to a message.
Attributes
----------
channel_id : snowflake
The channel ID the message is in.
messsage_id : snowflake
The ID of the message for which the reaction was added too.
user_id : snowflake
The ID of the user who added the reaction.
emoji : :class:`disco.types.message.MessageReactionEmoji`
The emoji which was added.
"""
channel_id = Field(snowflake)
message_id = Field(snowflake)
user_id = Field(snowflake)
emoji = Field(MessageReactionEmoji)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageReactionRemove(GatewayEvent):
"""
Sent when a reaction is removed from a message.
Attributes
----------
channel_id : snowflake
The channel ID the message is in.
messsage_id : snowflake
The ID of the message for which the reaction was removed from.
user_id : snowflake
The ID of the user who originally added the reaction.
emoji : :class:`disco.types.message.MessageReactionEmoji`
The emoji which was removed.
"""
channel_id = Field(snowflake)
message_id = Field(snowflake)
user_id = Field(snowflake)
emoji = Field(MessageReactionEmoji)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageReactionRemoveAll(GatewayEvent):
"""
Sent when all reactions are removed from a message.
Attributes
----------
channel_id : snowflake
The channel ID the message is in.
message_id : snowflake
The ID of the message for which the reactions where removed from.
"""
channel_id = Field(snowflake)
message_id = Field(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild

91
disco/gateway/ipc.py

@ -0,0 +1,91 @@
import random
import gevent
import string
import weakref
from holster.enum import Enum
from disco.util.logging import LoggingClass
from disco.util.serializer import dump_function, load_function
def get_random_str(size):
return ''.join([random.choice(string.printable) for _ in range(size)])
IPCMessageType = Enum(
'CALL_FUNC',
'GET_ATTR',
'EXECUTE',
'RESPONSE',
)
class GIPCProxy(LoggingClass):
def __init__(self, obj, pipe):
super(GIPCProxy, self).__init__()
self.obj = obj
self.pipe = pipe
self.results = weakref.WeakValueDictionary()
gevent.spawn(self.read_loop)
def resolve(self, parts):
base = self.obj
for part in parts:
base = getattr(base, part)
return base
def send(self, typ, data):
self.pipe.put((typ.value, data))
def handle(self, mtype, data):
if mtype == IPCMessageType.CALL_FUNC:
nonce, func, args, kwargs = data
res = self.resolve(func)(*args, **kwargs)
self.send(IPCMessageType.RESPONSE, (nonce, res))
elif mtype == IPCMessageType.GET_ATTR:
nonce, path = data
self.send(IPCMessageType.RESPONSE, (nonce, self.resolve(path)))
elif mtype == IPCMessageType.EXECUTE:
nonce, raw = data
func = load_function(raw)
try:
result = func(self.obj)
except Exception:
self.log.exception('Failed to EXECUTE: ')
result = None
self.send(IPCMessageType.RESPONSE, (nonce, result))
elif mtype == IPCMessageType.RESPONSE:
nonce, res = data
if nonce in self.results:
self.results[nonce].set(res)
def read_loop(self):
while True:
mtype, data = self.pipe.get()
try:
self.handle(mtype, data)
except:
self.log.exception('Error in GIPCProxy:')
def execute(self, func):
nonce = get_random_str(32)
raw = dump_function(func)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.EXECUTE.value, (nonce, raw)))
return result
def get(self, path):
nonce = get_random_str(32)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.GET_ATTR.value, (nonce, path)))
return result
def call(self, path, *args, **kwargs):
nonce = get_random_str(32)
self.results[nonce] = result = gevent.event.AsyncResult()
self.pipe.put((IPCMessageType.CALL_FUNC.value, (nonce, path, args, kwargs)))
return result

4
disco/gateway/packets.py

@ -1,7 +1,7 @@
from holster.enum import Enum
SEND = object()
RECV = object()
SEND = 1
RECV = 2
OPCode = Enum(
DISPATCH=0,

104
disco/gateway/sharder.py

@ -0,0 +1,104 @@
from __future__ import absolute_import
import gipc
import gevent
import pickle
import logging
import marshal
from six.moves import range
from disco.client import Client
from disco.bot import Bot, BotConfig
from disco.api.client import APIClient
from disco.gateway.ipc import GIPCProxy
from disco.util.logging import setup_logging
from disco.util.snowflake import calculate_shard
from disco.util.serializer import dump_function, load_function
def run_shard(config, shard_id, pipe):
setup_logging(
level=logging.INFO,
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(shard_id)
)
config.shard_id = shard_id
client = Client(config)
bot = Bot(client, BotConfig(config.bot))
bot.sharder = GIPCProxy(bot, pipe)
bot.shards = ShardHelper(config.shard_count, bot)
bot.run_forever()
class ShardHelper(object):
def __init__(self, count, bot):
self.count = count
self.bot = bot
def keys(self):
for sid in range(self.count):
yield sid
def on(self, id, func):
if id == self.bot.client.config.shard_id:
result = gevent.event.AsyncResult()
result.set(func(self.bot))
return result
return self.bot.sharder.call(('run_on', ), id, dump_function(func))
def all(self, func, timeout=None):
pool = gevent.pool.Pool(self.count)
return dict(zip(range(self.count), pool.imap(lambda i: self.on(i, func).wait(timeout=timeout), range(self.count))))
def for_id(self, sid, func):
shard = calculate_shard(self.count, sid)
return self.on(shard, func)
class AutoSharder(object):
def __init__(self, config):
self.config = config
self.client = APIClient(config.token)
self.shards = {}
self.config.shard_count = self.client.gateway_bot_get()['shards']
def run_on(self, sid, raw):
func = load_function(raw)
return self.shards[sid].execute(func).wait(timeout=15)
def run(self):
for shard in range(self.config.shard_count):
if self.config.manhole_enable and shard != 0:
self.config.manhole_enable = False
self.start_shard(shard)
gevent.sleep(6)
logging.basicConfig(
level=logging.INFO,
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id)
)
@staticmethod
def dumps(data):
if isinstance(data, (basestring, int, long, bool, list, set, dict)):
return '\x01' + marshal.dumps(data)
elif isinstance(data, object) and data.__class__.__name__ == 'code':
return '\x01' + marshal.dumps(data)
else:
return '\x02' + pickle.dumps(data)
@staticmethod
def loads(data):
enc_type = data[0]
if enc_type == '\x01':
return marshal.loads(data[1:])
elif enc_type == '\x02':
return pickle.loads(data[1:])
def start_shard(self, sid):
cpipe, ppipe = gipc.pipe(duplex=True, encoder=self.dumps, decoder=self.loads)
gipc.start_process(run_shard, (self.config, sid, cpipe))
self.shards[sid] = GIPCProxy(self, ppipe)

68
disco/state.py

@ -1,10 +1,11 @@
import six
import weakref
import inflection
from collections import deque, namedtuple
from weakref import WeakValueDictionary
from gevent.event import Event
from disco.types.base import UNSET
from disco.util.config import Config
from disco.util.hashmap import HashMap, DefaultHashMap
@ -88,7 +89,7 @@ class State(object):
EVENTS = [
'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove',
'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete',
'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate',
'GuildEmojisUpdate', 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate',
'PresenceUpdate'
]
@ -102,9 +103,9 @@ class State(object):
self.me = None
self.dms = HashMap()
self.guilds = HashMap()
self.channels = HashMap(WeakValueDictionary())
self.users = HashMap(WeakValueDictionary())
self.voice_states = HashMap(WeakValueDictionary())
self.channels = HashMap(weakref.WeakValueDictionary())
self.users = HashMap(weakref.WeakValueDictionary())
self.voice_states = HashMap(weakref.WeakValueDictionary())
# If message tracking is enabled, listen to those events
if self.config.track_messages:
@ -117,7 +118,7 @@ class State(object):
def unbind(self):
"""
Unbinds all bound event listeners for this state object
Unbinds all bound event listeners for this state object.
"""
map(lambda k: k.unbind(), self.listeners)
self.listeners = []
@ -185,11 +186,19 @@ class State(object):
for member in six.itervalues(event.guild.members):
self.users[member.user.id] = member.user
for voice_state in six.itervalues(event.guild.voice_states):
self.voice_states[voice_state.session_id] = voice_state
if self.config.sync_guild_members:
event.guild.sync()
def on_guild_update(self, event):
self.guilds[event.guild.id].update(event.guild)
self.guilds[event.guild.id].update(event.guild, ignored=[
'channels',
'members',
'voice_states',
'presences'
])
def on_guild_delete(self, event):
if event.id in self.guilds:
@ -208,6 +217,10 @@ class State(object):
if event.channel.id in self.channels:
self.channels[event.channel.id].update(event.channel)
if event.overwrites is not UNSET:
self.channels[event.channel.id].overwrites = event.overwrites
self.channels[event.channel.id].after_load()
def on_channel_delete(self, event):
if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels:
del event.channel.guild.channels[event.channel.id]
@ -215,18 +228,22 @@ class State(object):
del self.dms[event.channel.id]
def on_voice_state_update(self, event):
# Happy path: we have the voice state and want to update/delete it
guild = self.guilds.get(event.state.guild_id)
if not guild:
return
if event.state.session_id in guild.voice_states:
# Existing connection, we are either moving channels or disconnecting
if event.state.session_id in self.voice_states:
# Moving channels
if event.state.channel_id:
guild.voice_states[event.state.session_id].update(event.state)
self.voice_states[event.state.session_id].update(event.state)
# Disconnection
else:
del guild.voice_states[event.state.session_id]
if event.state.guild_id in self.guilds:
if event.state.session_id in self.guilds[event.state.guild_id].voice_states:
del self.guilds[event.state.guild_id].voice_states[event.state.session_id]
del self.voice_states[event.state.session_id]
# New connection
elif event.state.channel_id:
guild.voice_states[event.state.session_id] = event.state
if event.state.guild_id in self.guilds:
self.guilds[event.state.guild_id].voice_states[event.state.session_id] = event.state
self.voice_states[event.state.session_id] = event.state
def on_guild_member_add(self, event):
if event.member.user.id not in self.users:
@ -243,6 +260,9 @@ class State(object):
if event.member.guild_id not in self.guilds:
return
if event.member.id not in self.guilds[event.member.guild_id].members:
return
self.guilds[event.member.guild_id].members[event.member.id].update(event.member)
def on_guild_member_remove(self, event):
@ -285,6 +305,22 @@ class State(object):
del self.guilds[event.guild_id].roles[event.role_id]
def on_guild_emojis_update(self, event):
if event.guild_id not in self.guilds:
return
self.guilds[event.guild_id].emojis = HashMap({i.id: i for i in event.emojis})
def on_presence_update(self, event):
if event.user.id in self.users:
self.users[event.user.id].update(event.presence.user)
self.users[event.user.id].presence = event.presence
event.presence.user = self.users[event.user.id]
if event.guild_id not in self.guilds:
return
if event.user.id not in self.guilds[event.guild_id].members:
return
self.guilds[event.guild_id].members[event.user.id].user.update(event.user)

1
disco/types/__init__.py

@ -1,3 +1,4 @@
from disco.types.base import UNSET
from disco.types.channel import Channel
from disco.types.guild import Guild, GuildMember, Role
from disco.types.user import User

272
disco/types/base.py

@ -3,7 +3,7 @@ import gevent
import inspect
import functools
from holster.enum import BaseEnumMeta
from holster.enum import BaseEnumMeta, EnumAttr
from datetime import datetime as real_datetime
from disco.util.functional import CachedSlotProperty
@ -15,45 +15,61 @@ DATETIME_FORMATS = [
]
def get_item_by_path(obj, path):
for part in path.split('.'):
obj = getattr(obj, part)
return obj
class Unset(object):
def __nonzero__(self):
return False
UNSET = Unset()
class ConversionError(Exception):
def __init__(self, field, raw, e):
super(ConversionError, self).__init__(
'Failed to convert `{}` (`{}`) to {}: {}'.format(
str(raw)[:144], field.src_name, field.typ, e))
str(raw)[:144], field.src_name, field.true_type, e))
if six.PY3:
self.__cause__ = e
class FieldType(object):
def __init__(self, typ):
if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model):
self.typ = typ
elif isinstance(typ, BaseEnumMeta):
self.typ = lambda raw, _: typ.get(raw)
elif typ is None:
self.typ = lambda x, y: None
else:
self.typ = lambda raw, _: typ(raw)
def try_convert(self, raw, client):
pass
def __call__(self, raw, client):
return self.try_convert(raw, client)
class Field(object):
def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, cast=None, **kwargs):
# TODO: fix default bullshit
self.true_type = value_type
self.src_name = alias
self.dst_name = None
self.ignore_dump = ignore_dump or []
self.cast = cast
self.metadata = kwargs
if default is not None:
self.default = default
elif not hasattr(self, 'default'):
self.default = None
class Field(FieldType):
def __init__(self, typ, alias=None, default=None):
super(Field, self).__init__(typ)
self.deserializer = None
# Set names
self.src_name = alias
self.dst_name = None
if value_type:
self.deserializer = self.type_to_deserializer(value_type)
self.default = default
if isinstance(self.deserializer, Field) and self.default is None:
self.default = self.deserializer.default
elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None and create:
self.default = self.deserializer
if isinstance(self.typ, FieldType):
self.default = self.typ.default
@property
def name(self):
return None
def set_name(self, name):
@name.setter
def name(self, name):
if not self.dst_name:
self.dst_name = name
@ -65,31 +81,82 @@ class Field(FieldType):
def try_convert(self, raw, client):
try:
return self.typ(raw, client)
return self.deserializer(raw, client)
except Exception as e:
six.raise_from(ConversionError(self, raw, e), e)
six.reraise(ConversionError, ConversionError(self, raw, e))
@staticmethod
def type_to_deserializer(typ):
if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model):
return typ
elif isinstance(typ, BaseEnumMeta):
return lambda raw, _: typ.get(raw)
elif typ is None:
return lambda x, y: None
else:
return lambda raw, _: typ(raw)
@staticmethod
def serialize(value, inst=None):
if isinstance(value, EnumAttr):
return value.value
elif isinstance(value, Model):
return value.to_dict(ignore=(inst.ignore_dump if inst else []))
else:
if inst and inst.cast:
return inst.cast(value)
return value
class _Dict(FieldType):
def __call__(self, raw, client):
return self.try_convert(raw, client)
class DictField(Field):
default = HashMap
def __init__(self, typ, key=None):
super(_Dict, self).__init__(typ)
self.key = key
def __init__(self, key_type, value_type=None, **kwargs):
super(DictField, self).__init__({}, **kwargs)
self.true_key_type = key_type
self.true_value_type = value_type
self.key_de = self.type_to_deserializer(key_type)
self.value_de = self.type_to_deserializer(value_type or key_type)
@staticmethod
def serialize(value, inst=None):
return {
Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value)
if k not in (inst.ignore_dump if inst else [])
}
def try_convert(self, raw, client):
if self.key:
converted = [self.typ(i, client) for i in raw]
return HashMap({getattr(i, self.key): i for i in converted})
else:
return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)})
return HashMap({
self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw)
})
class _List(FieldType):
class ListField(Field):
default = list
@staticmethod
def serialize(value, inst=None):
return list(map(Field.serialize, value))
def try_convert(self, raw, client):
return [self.deserializer(i, client) for i in raw]
class AutoDictField(Field):
default = HashMap
def __init__(self, value_type, key, **kwargs):
super(AutoDictField, self).__init__({}, **kwargs)
self.value_de = self.type_to_deserializer(value_type)
self.key = key
def try_convert(self, raw, client):
return [self.typ(i, client) for i in raw]
return HashMap({
getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw)
})
def _make(typ, data, client):
@ -104,37 +171,19 @@ def snowflake(data):
def enum(typ):
def _f(data):
if isinstance(data, str):
data = data.lower()
return typ.get(data) if data is not None else None
return _f
def listof(*args, **kwargs):
return _List(*args, **kwargs)
def dictof(*args, **kwargs):
return _Dict(*args, **kwargs)
def lazy_datetime(data):
if not data:
return property(lambda: None)
def get():
for fmt in DATETIME_FORMATS:
try:
return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
except (ValueError, TypeError):
continue
raise ValueError('Failed to conver `{}` to datetime'.format(data))
return property(get)
def datetime(data):
if not data:
return None
if isinstance(data, int):
return real_datetime.utcfromtimestamp(data)
for fmt in DATETIME_FORMATS:
try:
return real_datetime.strptime(data.rsplit('+', 1)[0], fmt)
@ -145,6 +194,9 @@ def datetime(data):
def text(obj):
if obj is None:
return None
if six.PY2:
if isinstance(obj, str):
return obj.decode('utf-8')
@ -154,6 +206,9 @@ def text(obj):
def binary(obj):
if obj is None:
return None
if six.PY2:
if isinstance(obj, str):
return obj.decode('utf-8')
@ -165,13 +220,16 @@ def binary(obj):
def with_equality(field):
class T(object):
def __eq__(self, other):
return getattr(self, field) == getattr(other, field)
if isinstance(other, self.__class__):
return getattr(self, field) == getattr(other, field)
else:
return getattr(self, field) == other
return T
def with_hash(field):
class T(object):
def __hash__(self, other):
def __hash__(self):
return hash(getattr(self, field))
return T
@ -182,7 +240,7 @@ SlottedModel = None
class ModelMeta(type):
def __new__(cls, name, parents, dct):
def __new__(mcs, name, parents, dct):
fields = {}
for parent in parents:
@ -193,7 +251,7 @@ class ModelMeta(type):
if not isinstance(v, Field):
continue
v.set_name(k)
v.name = k
fields[k] = v
if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)):
@ -209,7 +267,7 @@ class ModelMeta(type):
dct = {k: v for k, v in six.iteritems(dct) if k not in fields}
dct['_fields'] = fields
return super(ModelMeta, cls).__new__(cls, name, parents, dct)
return super(ModelMeta, mcs).__new__(mcs, name, parents, dct)
class AsyncChainable(object):
@ -233,23 +291,49 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
else:
obj = kwargs
for name, field in six.iteritems(self.__class__._fields):
if field.src_name not in obj or obj[field.src_name] is None:
if field.has_default():
default = field.default() if callable(field.default) else field.default
else:
default = None
setattr(self, field.dst_name, default)
self.load(obj)
self.validate()
def validate(self):
pass
@property
def _fields(self):
return self.__class__._fields
def load(self, obj, consume=False, skip=None):
return self.load_into(self, obj, consume, skip)
def load_into(self, inst, obj, consume=False, skip=None):
for name, field in six.iteritems(self._fields):
should_skip = skip and name in skip
if consume and not should_skip:
raw = obj.pop(field.src_name, UNSET)
else:
raw = obj.get(field.src_name, UNSET)
# If the field is unset/none, and we have a default we need to set it
if (raw in (None, UNSET) or should_skip) and field.has_default():
default = field.default() if callable(field.default) else field.default
setattr(inst, field.dst_name, default)
continue
# Otherwise if the field is UNSET and has no default, skip conversion
if raw is UNSET or should_skip:
setattr(inst, field.dst_name, raw)
continue
value = field.try_convert(obj[field.src_name], self.client)
setattr(self, field.dst_name, value)
value = field.try_convert(raw, self.client)
setattr(inst, field.dst_name, value)
def update(self, other):
for name in six.iterkeys(self.__class__._fields):
value = getattr(other, name)
if value:
setattr(self, name, value)
def update(self, other, ignored=None):
for name in six.iterkeys(self._fields):
if ignored and name in ignored:
continue
if hasattr(other, name) and not getattr(other, name) is UNSET:
setattr(self, name, getattr(other, name))
# Clear cached properties
for name in dir(type(self)):
@ -259,8 +343,16 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
except:
pass
def to_dict(self):
return {k: getattr(self, k) for k in six.iterkeys(self.__class__._fields)}
def to_dict(self, ignore=None):
obj = {}
for name, field in six.iteritems(self.__class__._fields):
if ignore and name in ignore:
continue
if getattr(self, name) == UNSET:
continue
obj[name] = field.serialize(getattr(self, name), field)
return obj
@classmethod
def create(cls, client, data, **kwargs):
@ -269,8 +361,16 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
return inst
@classmethod
def create_map(cls, client, data):
return list(map(functools.partial(cls.create, client), data))
def create_map(cls, client, data, **kwargs):
return list(map(functools.partial(cls.create, client, **kwargs), data))
@classmethod
def create_hash(cls, client, key, data, **kwargs):
return HashMap({
get_item_by_path(item, key): item
for item in [
cls.create(client, item, **kwargs) for item in data]
})
@classmethod
def attach(cls, it, data):

102
disco/types/channel.py

@ -1,11 +1,12 @@
import six
from six.moves import map
from holster.enum import Enum
from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property, one_or_many, chunks
from disco.types.user import User
from disco.types.base import SlottedModel, Field, snowflake, enum, listof, dictof, text
from disco.types.base import SlottedModel, Field, AutoDictField, snowflake, enum, text
from disco.types.permissions import Permissions, Permissible, PermissionValue
from disco.voice.client import VoiceClient
@ -33,7 +34,7 @@ class ChannelSubType(SlottedModel):
class PermissionOverwrite(ChannelSubType):
"""
A PermissionOverwrite for a :class:`Channel`
A PermissionOverwrite for a :class:`Channel`.
Attributes
----------
@ -48,8 +49,8 @@ class PermissionOverwrite(ChannelSubType):
"""
id = Field(snowflake)
type = Field(enum(PermissionOverwriteType))
allow = Field(PermissionValue)
deny = Field(PermissionValue)
allow = Field(PermissionValue, cast=int)
deny = Field(PermissionValue, cast=int)
channel_id = Field(snowflake)
@ -57,22 +58,29 @@ class PermissionOverwrite(ChannelSubType):
def create(cls, channel, entity, allow=0, deny=0):
from disco.types.guild import Role
type = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER
ptype = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER
return cls(
client=channel.client,
id=entity.id,
type=type,
type=ptype,
allow=allow,
deny=deny,
channel_id=channel.id
).save()
@property
def compiled(self):
value = PermissionValue()
value -= self.deny
value += self.allow
return value
def save(self):
self.client.api.channels_permissions_modify(self.channel_id,
self.id,
self.allow.value or 0,
self.deny.value or 0,
self.type.name)
self.id,
self.allow.value or 0,
self.deny.value or 0,
self.type.name)
return self
def delete(self):
@ -81,7 +89,7 @@ class PermissionOverwrite(ChannelSubType):
class Channel(SlottedModel, Permissible):
"""
Represents a Discord Channel
Represents a Discord Channel.
Attributes
----------
@ -111,18 +119,27 @@ class Channel(SlottedModel, Permissible):
last_message_id = Field(snowflake)
position = Field(int)
bitrate = Field(int)
recipients = Field(listof(User))
recipients = AutoDictField(User, 'id')
type = Field(enum(ChannelType))
overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites')
overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites')
def __init__(self, *args, **kwargs):
super(Channel, self).__init__(*args, **kwargs)
self.after_load()
def after_load(self):
# TODO: hackfix
self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self})
def __str__(self):
return u'#{}'.format(self.name)
def __repr__(self):
return u'<Channel {} ({})>'.format(self.id, self)
def get_permissions(self, user):
"""
Get the permissions a user has in the channel
Get the permissions a user has in the channel.
Returns
-------
@ -132,8 +149,8 @@ class Channel(SlottedModel, Permissible):
if not self.guild_id:
return Permissions.ADMINISTRATOR
member = self.guild.members.get(user.id)
base = self.guild.get_permissions(user)
member = self.guild.get_member(user)
base = self.guild.get_permissions(member)
for ow in six.itervalues(self.overwrites):
if ow.id != user.id and ow.id not in member.roles:
@ -144,48 +161,55 @@ class Channel(SlottedModel, Permissible):
return base
@property
def mention(self):
return '<#{}>'.format(self.id)
@property
def is_guild(self):
"""
Whether this channel belongs to a guild
Whether this channel belongs to a guild.
"""
return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE)
@property
def is_dm(self):
"""
Whether this channel is a DM (does not belong to a guild)
Whether this channel is a DM (does not belong to a guild).
"""
return self.type in (ChannelType.DM, ChannelType.GROUP_DM)
@property
def is_voice(self):
"""
Whether this channel supports voice
Whether this channel supports voice.
"""
return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM)
@property
def messages(self):
"""
a default :class:`MessageIterator` for the channel
a default :class:`MessageIterator` for the channel.
"""
return self.messages_iter()
@cached_property
def guild(self):
"""
Guild this channel belongs to (if relevant)
Guild this channel belongs to (if relevant).
"""
return self.client.state.guilds.get(self.guild_id)
def messages_iter(self, **kwargs):
"""
Creates a new :class:`MessageIterator` for the channel with the given
keyword arguments
keyword arguments.
"""
return MessageIterator(self.client, self, **kwargs)
def get_message(self, message):
return self.client.api.channels_messages_get(self.id, to_snowflake(message))
def get_invites(self):
"""
Returns
@ -220,9 +244,9 @@ class Channel(SlottedModel, Permissible):
def create_webhook(self, name=None, avatar=None):
return self.client.api.channels_webhooks_create(self.id, name, avatar)
def send_message(self, content, nonce=None, tts=False):
def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None):
"""
Send a message in this channel
Send a message in this channel.
Parameters
----------
@ -238,11 +262,11 @@ class Channel(SlottedModel, Permissible):
:class:`disco.types.message.Message`
The created message.
"""
return self.client.api.channels_messages_create(self.id, content, nonce, tts)
return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed)
def connect(self, *args, **kwargs):
"""
Connect to this channel over voice
Connect to this channel over voice.
"""
assert self.is_voice, 'Channel must support voice to connect'
vc = VoiceClient(self)
@ -275,17 +299,29 @@ class Channel(SlottedModel, Permissible):
List of messages (or message ids) to delete. All messages must originate
from this channel.
"""
messages = map(to_snowflake, messages)
message_ids = list(map(to_snowflake, messages))
if not messages:
if not message_ids:
return
if len(messages) <= 2:
if self.can(self.client.state.me, Permissions.MANAGE_MESSAGES) and len(messages) > 2:
for chunk in chunks(message_ids, 100):
self.client.api.channels_messages_delete_bulk(self.id, chunk)
else:
for msg in messages:
self.delete_message(msg)
else:
for chunk in chunks(messages, 100):
self.client.api.channels_messages_delete_bulk(self.id, chunk)
def delete(self):
assert (self.is_dm or self.guild.can(self.client.state.me, Permissions.MANAGE_GUILD)), 'Invalid Permissions'
self.client.api.channels_delete(self.id)
def close(self):
"""
Closes a DM channel. This is intended as a safer version of `delete`,
enforcing that the channel is actually a DM.
"""
assert self.is_dm, 'Cannot close non-DM channel'
self.delete()
class MessageIterator(object):
@ -329,7 +365,7 @@ class MessageIterator(object):
def fill(self):
"""
Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API
Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API.
"""
self._buffer = self.client.api.channels_messages_list(
self.channel.id,

129
disco/types/guild.py

@ -6,10 +6,13 @@ from disco.gateway.packets import OPCode
from disco.api.http import APIException
from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property
from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum
from disco.types.user import User
from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum, datetime
)
from disco.types.user import User, Presence
from disco.types.voice import VoiceState
from disco.types.channel import Channel
from disco.types.message import Emoji
from disco.types.permissions import PermissionValue, Permissions, Permissible
@ -18,21 +21,12 @@ VerificationLevel = Enum(
LOW=1,
MEDIUM=2,
HIGH=3,
EXTREME=4,
)
class GuildSubType(SlottedModel):
guild_id = Field(None)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class Emoji(GuildSubType):
class GuildEmoji(Emoji):
"""
An emoji object
An emoji object.
Attributes
----------
@ -48,15 +42,27 @@ class Emoji(GuildSubType):
Roles this emoji is attached to.
"""
id = Field(snowflake)
guild_id = Field(snowflake)
name = Field(text)
require_colons = Field(bool)
managed = Field(bool)
roles = Field(listof(snowflake))
roles = ListField(snowflake)
def __str__(self):
return u'<:{}:{}>'.format(self.name, self.id)
class Role(GuildSubType):
@property
def url(self):
return 'https://discordapp.com/api/emojis/{}.png'.format(self.id)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class Role(SlottedModel):
"""
A role object
A role object.
Attributes
----------
@ -76,6 +82,7 @@ class Role(GuildSubType):
The position of this role in the hierarchy.
"""
id = Field(snowflake)
guild_id = Field(snowflake)
name = Field(text)
hoist = Field(bool)
managed = Field(bool)
@ -84,6 +91,9 @@ class Role(GuildSubType):
position = Field(int)
mentionable = Field(bool)
def __str__(self):
return self.name
def delete(self):
self.guild.delete_role(self)
@ -94,10 +104,19 @@ class Role(GuildSubType):
def mention(self):
return '<@{}>'.format(self.id)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
class GuildMember(GuildSubType):
class GuildBan(SlottedModel):
user = Field(User)
reason = Field(str)
class GuildMember(SlottedModel):
"""
A GuildMember object
A GuildMember object.
Attributes
----------
@ -121,8 +140,18 @@ class GuildMember(GuildSubType):
nick = Field(text)
mute = Field(bool)
deaf = Field(bool)
joined_at = Field(str)
roles = Field(listof(snowflake))
joined_at = Field(datetime)
roles = ListField(snowflake)
def __str__(self):
return self.user.__str__()
@property
def name(self):
"""
The nickname of this user if set, otherwise their username
"""
return self.nick or self.user.username
def get_voice_state(self):
"""
@ -151,6 +180,12 @@ class GuildMember(GuildSubType):
"""
self.guild.create_ban(self, delete_message_days)
def unban(self):
"""
Unbans the member from the guild.
"""
self.guild.delete_ban(self)
def set_nickname(self, nickname=None):
"""
Sets the member's nickname (or clears it if None).
@ -160,11 +195,19 @@ class GuildMember(GuildSubType):
nickname : Optional[str]
The nickname (or none to reset) to set.
"""
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '')
if self.client.state.me.id == self.user.id:
self.client.api.guilds_members_me_nick(self.guild.id, nick=nickname or '')
else:
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '')
def modify(self, **kwargs):
self.client.api.guilds_members_modify(self.guild.id, self.user.id, **kwargs)
def add_role(self, role):
roles = self.roles + [role.id]
self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles)
self.client.api.guilds_members_roles_add(self.guild.id, self.user.id, to_snowflake(role))
def remove_role(self, role):
self.client.api.guilds_members_roles_remove(self.guild.id, self.user.id, to_snowflake(role))
@cached_property
def owner(self):
@ -179,14 +222,22 @@ class GuildMember(GuildSubType):
@property
def id(self):
"""
Alias to the guild members user id
Alias to the guild members user id.
"""
return self.user.id
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@cached_property
def permissions(self):
return self.guild.get_permissions(self)
class Guild(SlottedModel, Permissible):
"""
A guild object
A guild object.
Attributes
----------
@ -222,7 +273,7 @@ class Guild(SlottedModel, Permissible):
All of the guild's channels.
roles : dict(snowflake, :class:`Role`)
All of the guild's roles.
emojis : dict(snowflake, :class:`Emoji`)
emojis : dict(snowflake, :class:`GuildEmoji`)
All of the guild's emojis.
voice_states : dict(str, :class:`disco.types.voice.VoiceState`)
All of the guild's voice states.
@ -239,12 +290,14 @@ class Guild(SlottedModel, Permissible):
embed_enabled = Field(bool)
verification_level = Field(enum(VerificationLevel))
mfa_level = Field(int)
features = Field(listof(str))
members = Field(dictof(GuildMember, key='id'))
channels = Field(dictof(Channel, key='id'))
roles = Field(dictof(Role, key='id'))
emojis = Field(dictof(Emoji, key='id'))
voice_states = Field(dictof(VoiceState, key='session_id'))
features = ListField(str)
members = AutoDictField(GuildMember, 'id')
channels = AutoDictField(Channel, 'id')
roles = AutoDictField(Role, 'id')
emojis = AutoDictField(GuildEmoji, 'id')
voice_states = AutoDictField(VoiceState, 'session_id')
member_count = Field(int)
presences = ListField(Presence)
synced = Field(bool, default=False)
@ -257,7 +310,7 @@ class Guild(SlottedModel, Permissible):
self.attach(six.itervalues(self.emojis), {'guild_id': self.id})
self.attach(six.itervalues(self.voice_states), {'guild_id': self.id})
def get_permissions(self, user):
def get_permissions(self, member):
"""
Get the permissions a user has in this guild.
@ -266,10 +319,13 @@ class Guild(SlottedModel, Permissible):
:class:`disco.types.permissions.PermissionValue`
Computed permission value for the user.
"""
if self.owner_id == user.id:
if not isinstance(member, GuildMember):
member = self.get_member(member)
# Owner has all permissions
if self.owner_id == member.id:
return PermissionValue(Permissions.ADMINISTRATOR)
member = self.get_member(user)
value = PermissionValue(self.roles.get(self.id).permissions)
for role in map(self.roles.get, member.roles):
@ -358,3 +414,6 @@ class Guild(SlottedModel, Permissible):
def create_ban(self, user, delete_message_days=0):
self.client.api.guilds_bans_create(self.id, to_snowflake(user), delete_message_days)
def create_channel(self, *args, **kwargs):
return self.client.api.guilds_channels_create(self.id, *args, **kwargs)

6
disco/types/invite.py

@ -1,4 +1,4 @@
from disco.types.base import SlottedModel, Field, lazy_datetime
from disco.types.base import SlottedModel, Field, datetime
from disco.types.user import User
from disco.types.guild import Guild
from disco.types.channel import Channel
@ -6,7 +6,7 @@ from disco.types.channel import Channel
class Invite(SlottedModel):
"""
An invite object
An invite object.
Attributes
----------
@ -37,7 +37,7 @@ class Invite(SlottedModel):
max_uses = Field(int)
uses = Field(int)
temporary = Field(bool)
created_at = Field(lazy_datetime)
created_at = Field(datetime)
@classmethod
def create(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False):

200
disco/types/message.py

@ -1,8 +1,14 @@
import re
import six
import functools
import unicodedata
from holster.enum import Enum
from disco.types.base import SlottedModel, Field, snowflake, text, lazy_datetime, dictof, listof, enum
from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text,
datetime, enum
)
from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property
from disco.types.user import User
@ -19,6 +25,31 @@ MessageType = Enum(
)
class Emoji(SlottedModel):
id = Field(snowflake)
name = Field(text)
def __eq__(self, other):
if isinstance(other, Emoji):
return self.id == other.id and self.name == other.name
raise NotImplementedError
def to_string(self):
if self.id:
return '{}:{}'.format(self.name, self.id)
return self.name
class MessageReactionEmoji(Emoji):
pass
class MessageReaction(SlottedModel):
emoji = Field(MessageReactionEmoji)
count = Field(int)
me = Field(bool)
class MessageEmbedFooter(SlottedModel):
text = Field(text)
icon_url = Field(text)
@ -60,7 +91,7 @@ class MessageEmbedField(SlottedModel):
class MessageEmbed(SlottedModel):
"""
Message embed object
Message embed object.
Attributes
----------
@ -76,20 +107,38 @@ class MessageEmbed(SlottedModel):
title = Field(text)
type = Field(str, default='rich')
description = Field(text)
url = Field(str)
timestamp = Field(lazy_datetime)
url = Field(text)
timestamp = Field(datetime)
color = Field(int)
footer = Field(MessageEmbedFooter)
image = Field(MessageEmbedImage)
thumbnail = Field(MessageEmbedThumbnail)
video = Field(MessageEmbedVideo)
author = Field(MessageEmbedAuthor)
fields = Field(listof(MessageEmbedField))
fields = ListField(MessageEmbedField)
def set_footer(self, *args, **kwargs):
self.footer = MessageEmbedFooter(*args, **kwargs)
def set_image(self, *args, **kwargs):
self.image = MessageEmbedImage(*args, **kwargs)
def set_thumbnail(self, *args, **kwargs):
self.thumbnail = MessageEmbedThumbnail(*args, **kwargs)
def set_video(self, *args, **kwargs):
self.video = MessageEmbedVideo(*args, **kwargs)
def set_author(self, *args, **kwargs):
self.author = MessageEmbedAuthor(*args, **kwargs)
def add_field(self, *args, **kwargs):
self.fields.append(MessageEmbedField(*args, **kwargs))
class MessageAttachment(SlottedModel):
"""
Message attachment object
Message attachment object.
Attributes
----------
@ -110,8 +159,8 @@ class MessageAttachment(SlottedModel):
"""
id = Field(str)
filename = Field(text)
url = Field(str)
proxy_url = Field(str)
url = Field(text)
proxy_url = Field(text)
size = Field(int)
height = Field(int)
width = Field(int)
@ -161,15 +210,16 @@ class Message(SlottedModel):
author = Field(User)
content = Field(text)
nonce = Field(snowflake)
timestamp = Field(lazy_datetime)
edited_timestamp = Field(lazy_datetime)
timestamp = Field(datetime)
edited_timestamp = Field(datetime)
tts = Field(bool)
mention_everyone = Field(bool)
pinned = Field(bool)
mentions = Field(dictof(User, key='id'))
mention_roles = Field(listof(snowflake))
embeds = Field(listof(MessageEmbed))
attachments = Field(dictof(MessageAttachment, key='id'))
mentions = AutoDictField(User, 'id')
mention_roles = ListField(snowflake)
embeds = ListField(MessageEmbed)
attachments = AutoDictField(MessageAttachment, 'id')
reactions = ListField(MessageReaction)
def __str__(self):
return '<Message {} ({})>'.format(self.id, self.channel_id)
@ -213,7 +263,7 @@ class Message(SlottedModel):
def reply(self, *args, **kwargs):
"""
Reply to this message (proxys arguments to
:func:`disco.types.channel.Channel.send_message`)
:func:`disco.types.channel.Channel.send_message`).
Returns
-------
@ -222,9 +272,9 @@ class Message(SlottedModel):
"""
return self.channel.send_message(*args, **kwargs)
def edit(self, content):
def edit(self, *args, **kwargs):
"""
Edit this message
Edit this message.
Args
----
@ -236,7 +286,7 @@ class Message(SlottedModel):
:class:`Message`
The edited message object.
"""
return self.client.api.channels_messages_modify(self.channel_id, self.id, content)
return self.client.api.channels_messages_modify(self.channel_id, self.id, *args, **kwargs)
def delete(self):
"""
@ -249,6 +299,42 @@ class Message(SlottedModel):
"""
return self.client.api.channels_messages_delete(self.channel_id, self.id)
def get_reactors(self, emoji):
"""
Returns an list of users who reacted to this message with the given emoji.
Returns
-------
list(:class:`User`)
The users who reacted.
"""
return self.client.api.channels_messages_reactions_get(
self.channel_id,
self.id,
emoji
)
def create_reaction(self, emoji):
if isinstance(emoji, Emoji):
emoji = emoji.to_string()
self.client.api.channels_messages_reactions_create(
self.channel_id,
self.id,
emoji)
def delete_reaction(self, emoji, user=None):
if isinstance(emoji, Emoji):
emoji = emoji.to_string()
if user:
user = to_snowflake(user)
self.client.api.channels_messages_reactions_delete(
self.channel_id,
self.id,
emoji,
user)
def is_mentioned(self, entity):
"""
Returns
@ -256,22 +342,37 @@ class Message(SlottedModel):
bool
Whether the give entity was mentioned.
"""
id = to_snowflake(entity)
return id in self.mentions or id in self.mention_roles
entity = to_snowflake(entity)
return entity in self.mentions or entity in self.mention_roles
@cached_property
def without_mentions(self):
def without_mentions(self, valid_only=False):
"""
Returns
-------
str
the message contents with all valid mentions removed.
the message contents with all mentions removed.
"""
return self.replace_mentions(
lambda u: '',
lambda r: '')
lambda r: '',
lambda c: '',
nonexistant=not valid_only)
@cached_property
def with_proper_mentions(self):
def replace_user(u):
return u'@' + six.text_type(u)
def replace_role(r):
return u'@' + six.text_type(r)
def replace_channel(c):
return six.text_type(c)
def replace_mentions(self, user_replace, role_replace):
return self.replace_mentions(replace_user, replace_role, replace_channel)
def replace_mentions(self, user_replace=None, role_replace=None, channel_replace=None, nonexistant=False):
"""
Replaces user and role mentions with the result of a given lambda/function.
@ -289,39 +390,55 @@ class Message(SlottedModel):
str
The message contents with all valid mentions replaced.
"""
if not self.mentions and not self.mention_roles:
return
def replace(getter, func, match):
oid = int(match.group(2))
obj = getter(oid)
if obj or nonexistant:
return func(obj or oid) or match.group(0)
return match.group(0)
content = self.content
if user_replace:
replace_user = functools.partial(replace, self.mentions.get, user_replace)
content = re.sub('(<@!?([0-9]+)>)', replace_user, content)
if role_replace:
replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace)
content = re.sub('(<@&([0-9]+)>)', replace_role, content)
def replace(match):
id = match.group(0)
if id in self.mention_roles:
return role_replace(id)
else:
return user_replace(self.mentions.get(id))
if channel_replace:
replace_channel = functools.partial(replace, self.client.state.channels.get, channel_replace)
content = re.sub('(<#([0-9]+)>)', replace_channel, content)
return re.sub('<@!?([0-9]+)>', replace, self.content)
return content
class MessageTable(object):
def __init__(self, sep=' | ', codeblock=True, header_break=True):
def __init__(self, sep=' | ', codeblock=True, header_break=True, language=None):
self.header = []
self.entries = []
self.size_index = {}
self.sep = sep
self.codeblock = codeblock
self.header_break = header_break
self.language = language
def recalculate_size_index(self, cols):
for idx, col in enumerate(cols):
if idx not in self.size_index or len(col) > self.size_index[idx]:
self.size_index[idx] = len(col)
size = len(unicodedata.normalize('NFC', col))
if idx not in self.size_index or size > self.size_index[idx]:
self.size_index[idx] = size
def set_header(self, *args):
args = list(map(six.text_type, args))
self.header = args
self.recalculate_size_index(args)
def add(self, *args):
args = list(map(str, args))
args = list(map(six.text_type, args))
self.entries.append(args)
self.recalculate_size_index(args)
@ -329,22 +446,23 @@ class MessageTable(object):
data = self.sep.lstrip()
for idx, col in enumerate(cols):
padding = ' ' * ((self.size_index[idx] - len(col)))
padding = ' ' * (self.size_index[idx] - len(col))
data += col + padding + self.sep
return data.rstrip()
def compile(self):
data = []
data.append(self.compile_one(self.header))
if self.header:
data = [self.compile_one(self.header)]
if self.header_break:
if self.header and self.header_break:
data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1))
for row in self.entries:
data.append(self.compile_one(row))
if self.codeblock:
return '```' + '\n'.join(data) + '```'
return '```{}'.format(self.language if self.language else '') + '\n'.join(data) + '```'
return '\n'.join(data)

12
disco/types/permissions.py

@ -76,13 +76,13 @@ class PermissionValue(object):
return self.sub(other)
def __getattribute__(self, name):
if name in Permissions.attrs:
if name in Permissions.keys_:
return (self.value & Permissions[name].value) == Permissions[name].value
else:
return object.__getattribute__(self, name)
def __setattr__(self, name, value):
if name not in Permissions.attrs:
if name not in Permissions.keys_:
return super(PermissionValue, self).__setattr__(name, value)
if value:
@ -90,9 +90,12 @@ class PermissionValue(object):
else:
self.value &= ~Permissions[name].value
def __int__(self):
return self.value
def to_dict(self):
return {
k: getattr(self, k) for k in Permissions.attrs
k: getattr(self, k) for k in Permissions.keys_
}
@classmethod
@ -107,6 +110,9 @@ class PermissionValue(object):
class Permissible(object):
__slots__ = []
def get_permissions(self):
raise NotImplementedError
def can(self, user, *args):
perms = self.get_permissions(user)
return perms.administrator or perms.can(*args)

40
disco/types/user.py

@ -2,30 +2,54 @@ from holster.enum import Enum
from disco.types.base import SlottedModel, Field, snowflake, text, binary, with_equality, with_hash
DefaultAvatars = Enum(
BLURPLE=0,
GREY=1,
GREEN=2,
ORANGE=3,
RED=4,
)
class User(SlottedModel, with_equality('id'), with_hash('id')):
id = Field(snowflake)
username = Field(text)
avatar = Field(binary)
discriminator = Field(str)
bot = Field(bool)
bot = Field(bool, default=False)
verified = Field(bool)
email = Field(str)
presence = Field(None)
def get_avatar_url(self, fmt='webp', size=1024):
if not self.avatar:
return 'https://cdn.discordapp.com/embed/avatars/{}.png'.format(self.default_avatar.value)
return 'https://cdn.discordapp.com/avatars/{}/{}.{}?size={}'.format(
self.id,
self.avatar,
fmt,
size
)
@property
def default_avatar(self):
return DefaultAvatars[int(self.discriminator) % len(DefaultAvatars.attrs)]
@property
def avatar_url(self):
return self.get_avatar_url()
@property
def mention(self):
return '<@{}>'.format(self.id)
def to_string(self):
return '{}#{}'.format(self.username, self.discriminator)
def __str__(self):
return '<User {} ({})>'.format(self.id, self.to_string())
return u'{}#{}'.format(self.username, str(self.discriminator).zfill(4))
def on_create(self):
self.client.state.users[self.id] = self
def __repr__(self):
return u'<User {} ({})>'.format(self.id, self)
GameType = Enum(
@ -49,6 +73,6 @@ class Game(SlottedModel):
class Presence(SlottedModel):
user = Field(User)
user = Field(User, alias='user', ignore_dump=['presence'])
game = Field(Game)
status = Field(Status)

2
disco/types/voice.py

@ -17,7 +17,7 @@ class VoiceState(SlottedModel):
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@cached_property
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)

6
disco/types/webhook.py

@ -32,12 +32,14 @@ class Webhook(SlottedModel):
else:
return self.client.api.webhooks_modify(self.id, name, avatar)
def execute(self, content=None, username=None, avatar_url=None, tts=False, file=None, embeds=[], wait=False):
def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False):
# TODO: support file stuff properly
return self.client.api.webhooks_token_execute(self.id, self.token, {
'content': content,
'username': username,
'avatar_url': avatar_url,
'tts': tts,
'file': file,
'file': fobj,
'embeds': [i.to_dict() for i in embeds],
}, wait)

2
disco/util/config.py

@ -29,7 +29,7 @@ class Config(object):
return inst
def from_prefix(self, prefix):
prefix = prefix + '_'
prefix += '_'
obj = {}
for k, v in six.iteritems(self.__dict__):

4
disco/util/hashmap.py

@ -45,12 +45,12 @@ class HashMap(UserDict):
def filter(self, predicate):
if not callable(predicate):
raise TypeError('predicate must be callable')
return filter(self.values(), predicate)
return filter(predicate, self.values())
def map(self, predicate):
if not callable(predicate):
raise TypeError('predicate must be callable')
return map(self.values(), predicate)
return map(predicate, self.values())
class DefaultHashMap(defaultdict, HashMap):

3
disco/util/limiter.py

@ -17,7 +17,8 @@ class SimpleLimiter(object):
gevent.sleep(self.reset_at - time.time())
self.count = 0
self.reset_at = 0
self.event.set()
if self.event:
self.event.set()
self.event = None
def check(self):

35
disco/util/logging.py

@ -3,15 +3,28 @@ from __future__ import absolute_import
import logging
LEVEL_OVERRIDES = {
'requests': logging.WARNING
}
LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'
def setup_logging(**kwargs):
kwargs.setdefault('format', LOG_FORMAT)
logging.basicConfig(**kwargs)
for logger, level in LEVEL_OVERRIDES.items():
logging.getLogger(logger).setLevel(level)
class LoggingClass(object):
def __init__(self):
self.log = logging.getLogger(self.__class__.__name__)
def log_on_error(self, msg, f):
def _f(*args, **kwargs):
try:
return f(*args, **kwargs)
except:
self.log.exception(msg)
raise
return _f
__slots__ = ['_log']
@property
def log(self):
try:
return self._log
except AttributeError:
self._log = logging.getLogger(self.__class__.__name__)
return self._log

36
disco/util/serializer.py

@ -1,3 +1,5 @@
import six
import types
class Serializer(object):
@ -36,3 +38,37 @@ class Serializer(object):
def dumps(cls, fmt, raw):
_, dumps = getattr(cls, fmt)()
return dumps(raw)
def dump_cell(cell):
return cell.cell_contents
def load_cell(cell):
if six.PY3:
return (lambda y: cell).__closure__[0]
else:
return (lambda y: cell).func_closure[0]
def dump_function(func):
if six.PY3:
return (
func.__code__,
func.__name__,
func.__defaults__,
list(map(dump_cell, func.__closure__)) if func.__closure__ else [],
)
else:
return (
func.func_code,
func.func_name,
func.func_defaults,
list(map(dump_cell, func.func_closure)) if func.func_closure else [],
)
def load_function(args):
code, name, defaults, closure = args
closure = tuple(map(load_cell, closure))
return types.FunctionType(code, globals(), name, defaults, closure)

2
disco/util/snowflake.py

@ -17,7 +17,7 @@ def to_unix(snowflake):
def to_unix_ms(snowflake):
return ((int(snowflake) >> 22) + DISCORD_EPOCH)
return (int(snowflake) >> 22) + DISCORD_EPOCH
def to_snowflake(i):

2
disco/util/token.py

@ -5,6 +5,6 @@ TOKEN_RE = re.compile(r'M\w{23}\.[\w-]{6}\..{27}')
def is_valid_token(token):
"""
Validates a Discord authentication token, returning true if valid
Validates a Discord authentication token, returning true if valid.
"""
return bool(TOKEN_RE.match(token))

7
disco/voice/client.py

@ -106,6 +106,7 @@ class VoiceClient(LoggingClass):
self.endpoint = None
self.ssrc = None
self.port = None
self.udp = None
self.update_listener = None
@ -158,7 +159,7 @@ class VoiceClient(LoggingClass):
}
})
def on_voice_sdp(self, data):
def on_voice_sdp(self, _):
# Toggle speaking state so clients learn of our SSRC
self.set_speaking(True)
self.set_speaking(False)
@ -187,11 +188,10 @@ class VoiceClient(LoggingClass):
def on_message(self, msg):
try:
data = self.encoder.decode(msg)
self.packets.emit(VoiceOPCode[data['op']], data['d'])
except:
self.log.exception('Failed to parse voice gateway message: ')
self.packets.emit(VoiceOPCode[data['op']], data['d'])
def on_error(self, err):
# TODO
self.log.warning('Voice websocket error: {}'.format(err))
@ -205,6 +205,7 @@ class VoiceClient(LoggingClass):
})
def on_close(self, code, error):
# TODO
self.log.warning('Voice websocket disconnected (%s, %s)', code, error)
if self.state == VoiceState.CONNECTED:

127
disco/voice/opus.py

@ -1,8 +1,15 @@
import sys
import array
import gevent
import ctypes
import ctypes.util
try:
from cStringIO import cStringIO as StringIO
except:
from StringIO import StringIO
from gevent.queue import Queue
from holster.enum import Enum
from disco.util.logging import LoggingClass
@ -43,12 +50,12 @@ class BaseOpus(LoggingClass):
for name, item in methods.items():
func = getattr(self.lib, name)
if item[1]:
func.argtypes = item[1]
if item[0]:
func.argtypes = item[0]
func.restype = item[2]
func.restype = item[1]
setattr(self, name.replace('opus_', ''), func)
setattr(self, name, func)
@staticmethod
def find_library():
@ -83,7 +90,7 @@ class OpusEncoder(BaseOpus):
}
def __init__(self, sampling, channels, application=Application.AUDIO, library_path=None):
super(OpusDecoder, self).__init__(library_path)
super(OpusEncoder, self).__init__(library_path)
self.sampling_rate = sampling
self.channels = channels
self.application = application
@ -94,10 +101,32 @@ class OpusEncoder(BaseOpus):
self.frame_size = self.samples_per_frame * self.sample_size
self.inst = self.create()
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
def set_bitrate(self, kbps):
kbps = min(128, max(16, int(kbps)))
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_BITRATE), kbps * 1024)
if ret < 0:
raise Exception('Failed to set bitrate to {}: {}'.format(kbps, ret))
def set_fec(self, value):
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_FEC), int(value))
if ret < 0:
raise Exception('Failed to set FEC to {}: {}'.format(value, ret))
def set_expected_packet_loss_percent(self, perc):
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_PLP), min(100, max(0, int(perc * 100))))
if ret < 0:
raise Exception('Failed to set PLP to {}: {}'.format(perc, ret))
def create(self):
ret = ctypes.c_int()
result = self.encoder_create(self.sampling_rate, self.channels, self.application.value, ctypes.byref(ret))
result = self.opus_encoder_create(self.sampling_rate, self.channels, self.application.value, ctypes.byref(ret))
if ret.value != 0:
raise Exception('Failed to create opus encoder: {}'.format(ret.value))
@ -106,7 +135,7 @@ class OpusEncoder(BaseOpus):
def __del__(self):
if self.inst:
self.encoder_destroy(self.inst)
self.opus_encoder_destroy(self.inst)
self.inst = None
def encode(self, pcm, frame_size):
@ -114,12 +143,92 @@ class OpusEncoder(BaseOpus):
pcm = ctypes.cast(pcm, c_int16_ptr)
data = (ctypes.c_char * max_data_bytes)()
ret = self.encode(self.inst, pcm, frame_size, data, max_data_bytes)
ret = self.opus_encode(self.inst, pcm, frame_size, data, max_data_bytes)
if ret < 0:
raise Exception('Failed to encode: {}'.format(ret))
return array.array('b', data[:ret]).tobytes()
# TODO: py3
return array.array('b', data[:ret]).tostring()
class OpusDecoder(BaseOpus):
pass
class BufferedOpusEncoder(OpusEncoder):
def __init__(self, data, *args, **kwargs):
self.data = StringIO(data)
self.frames = Queue(kwargs.pop('queue_size', 4096))
super(BufferedOpusEncoder, self).__init__(*args, **kwargs)
gevent.spawn(self._encoder_loop)
def _encoder_loop(self):
while self.data:
raw = self.data.read(self.frame_size)
if len(raw) < self.frame_size:
break
self.frames.put(self.encode(raw, self.samples_per_frame))
gevent.idle()
self.data = None
def have_frame(self):
return self.data or not self.frames.empty()
def next_frame(self):
return self.frames.get()
class GIPCBufferedOpusEncoder(OpusEncoder):
FIN = 1
def __init__(self, data, *args, **kwargs):
import gipc
self.data = StringIO(data)
self.parent_pipe, self.child_pipe = gipc.pipe(duplex=True)
self.frames = Queue(kwargs.pop('queue_size', 4096))
super(GIPCBufferedOpusEncoder, self).__init__(*args, **kwargs)
gipc.start_process(target=self._encoder_loop, args=(self.child_pipe, (args, kwargs)))
gevent.spawn(self._writer)
gevent.spawn(self._reader)
def _reader(self):
while True:
data = self.parent_pipe.get()
if data == self.FIN:
return
self.frames.put(data)
self.parent_pipe = None
def _writer(self):
while self.data:
raw = self.data.read(self.frame_size)
if len(raw) < self.frame_size:
break
self.parent_pipe.put(raw)
gevent.idle()
self.parent_pipe.put(self.FIN)
def have_frame(self):
return self.parent_pipe
def next_frame(self):
return self.frames.get()
@classmethod
def _encoder_loop(cls, pipe, (args, kwargs)):
encoder = OpusEncoder(*args, **kwargs)
while True:
data = pipe.get()
if data == cls.FIN:
pipe.put(cls.FIN)
return
pipe.put(encoder.encode(data, encoder.samples_per_frame))

49
disco/voice/player.py

@ -1,18 +1,54 @@
import time
import gevent
import struct
import time
import subprocess
from six.moves import queue
from disco.voice.client import VoiceState
from disco.voice.opus import BufferedOpusEncoder, GIPCBufferedOpusEncoder
class BaseFFmpegPlayable(object):
def __init__(self, source='-', command='avconv', sampling_rate=48000, channels=2, **kwargs):
args = [command, '-i', source, '-f', 's16le', '-ar', str(sampling_rate), '-ac', str(channels), '-loglevel', 'warning', 'pipe:1']
self.proc = subprocess.Popen(args, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
data, _ = self.proc.communicate()
super(BaseFFmpegPlayable, self).__init__(data, sampling_rate, channels, **kwargs)
class FFmpegPlayable(BaseFFmpegPlayable, BufferedOpusEncoder):
pass
class GIPCFFmpegPlayable(BaseFFmpegPlayable, GIPCBufferedOpusEncoder):
pass
class OpusItem(object):
def __init__(self, frame_length=20, channels=2):
def create_youtube_dl_playable(url, cls=FFmpegPlayable, *args, **kwargs):
import youtube_dl
ydl = youtube_dl.YoutubeDL({'format': 'webm[abr>0]/bestaudio/best'})
info = ydl.extract_info(url, download=False)
if 'entries' in info:
info = info['entries'][0]
return cls(info['url'], *args, **kwargs), info
class OpusPlayable(object):
"""
Represents a Playable item which is a cached set of Opus-encoded bytes.
"""
def __init__(self, sampling_rate=48000, frame_length=20, channels=2):
self.frames = []
self.idx = 0
self.frame_length = 20
self.sampling_rate = sampling_rate
self.frame_length = frame_length
self.channels = channels
self.sample_size = int(self.sampling_rate / 1000 * self.frame_length)
@classmethod
def from_raw_file(cls, path):
@ -58,6 +94,7 @@ class Player(object):
def play(self, item):
start = time.time()
loops = 0
timestamp = 0
while True:
loops += 1
@ -76,13 +113,15 @@ class Player(object):
if not item.have_frame():
return
self.client.send_frame(item.next_frame())
self.client.send_frame(item.next_frame(), loops, timestamp)
timestamp += item.samples_per_frame
next_time = start + 0.02 * loops
delay = max(0, 0.02 + (next_time - time.time()))
gevent.sleep(delay)
def run(self):
self.client.set_speaking(True)
while self.playing:
self.play(self.queue.get())
@ -90,4 +129,6 @@ class Player(object):
self.playing = False
self.complete.set()
return
self.client.set_speaking(False)
self.disconnect()

9
examples/music.py

@ -1,15 +1,16 @@
from disco.bot import Plugin
from disco.bot.command import CommandError
from disco.voice.client import Player, OpusItem, VoiceException
from disco.voice.player import Player, create_youtube_dl_playable
from disco.voice.client import VoiceException
def download(url):
return OpusItem.from_raw_file('test.dca')
return create_youtube_dl_playable(url)[0]
class MusicPlugin(Plugin):
def load(self):
super(MusicPlugin, self).load()
def load(self, ctx):
super(MusicPlugin, self).load(ctx)
self.guilds = {}
@Plugin.command('join')

2
requirements.txt

@ -1,5 +1,5 @@
gevent==1.1.2
holster==1.0.7
holster==1.0.11
inflection==0.3.1
requests==2.11.1
six==1.10.0

Loading…
Cancel
Save