Browse Source

Merge branch 'master' into feature/docs-redux

pull/20/head
Andrei 8 years ago
parent
commit
c8289bf5fe
  1. 3
      .gitignore
  2. 13
      .travis.yml
  3. 2
      disco/__init__.py
  4. 66
      disco/api/client.py
  5. 8
      disco/api/http.py
  6. 57
      disco/bot/bot.py
  7. 26
      disco/bot/command.py
  8. 6
      disco/bot/parser.py
  9. 11
      disco/bot/plugin.py
  10. 15
      disco/bot/providers/__init__.py
  11. 134
      disco/bot/providers/base.py
  12. 54
      disco/bot/providers/disk.py
  13. 5
      disco/bot/providers/memory.py
  14. 48
      disco/bot/providers/redis.py
  15. 52
      disco/bot/providers/rocksdb.py
  16. 97
      disco/bot/storage.py
  17. 44
      disco/client.py
  18. 10
      disco/gateway/client.py
  19. 11
      disco/gateway/events.py
  20. 68
      disco/state.py
  21. 20
      disco/types/base.py
  22. 15
      disco/types/channel.py
  23. 37
      disco/types/guild.py
  24. 2
      disco/types/invite.py
  25. 7
      disco/types/message.py
  26. 4
      disco/types/permissions.py
  27. 3
      disco/types/user.py
  28. 23
      disco/types/webhook.py
  29. 71
      disco/util/chains.py
  30. 9
      disco/util/logging.py
  31. 32
      disco/util/sanitize.py
  32. 9
      disco/util/snowflake.py
  33. 3
      disco/voice/__init__.py
  34. 116
      disco/voice/client.py
  35. 149
      disco/voice/opus.py
  36. 354
      disco/voice/playable.py
  37. 126
      disco/voice/player.py
  38. 4
      docs/SUMMARY.md
  39. 3
      docs/types/MESSAGE.md
  40. 2
      examples/basic_plugin.py
  41. 52
      examples/music.py
  42. 33
      examples/storage.py
  43. 8
      requirements.txt
  44. 15
      setup.py
  45. 47
      tests/test_bot.py
  46. 21
      tests/test_channel.py
  47. 32
      tests/test_embeds.py
  48. 42
      tests/test_imports.py

3
.gitignore

@ -3,6 +3,7 @@ dist/
disco*.egg-info/
docs/_build
storage.db
_book/
node_modules/
storage.json
*.dca

13
.travis.yml

@ -0,0 +1,13 @@
language: python
cache: pip
python:
- '2.7'
- '3.3'
- '3.4'
- '3.5'
- '3.6'
- 'nightly'
script: 'python setup.py test'

2
disco/__init__.py

@ -1 +1 @@
VERSION = '0.0.7'
VERSION = '0.0.8'

66
disco/api/client.py

@ -1,8 +1,10 @@
import six
import json
import warnings
from disco.api.http import Routes, HTTPClient
from disco.util.logging import LoggingClass
from disco.util.sanitize import S
from disco.types.user import User
from disco.types.message import Message
@ -88,29 +90,56 @@ class APIClient(LoggingClass):
r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message))
return Message.create(self.client, r.json())
def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None):
def channels_messages_create(self, channel, content=None, nonce=None, tts=False,
attachment=None, attachments=[], embed=None, sanitize=False):
payload = {
'content': content,
'nonce': nonce,
'tts': tts,
}
if attachment:
attachments = [attachment]
warnings.warn(
'attachment kwarg has been deprecated, switch to using attachments with a list',
DeprecationWarning)
if content:
if sanitize:
content = S(content)
payload['content'] = content
if embed:
payload['embed'] = embed.to_dict()
if attachment:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data={'payload_json': json.dumps(payload)}, files={
'file': (attachment[0], attachment[1])
})
if attachments:
if len(attachments) > 1:
files = {
'file{}'.format(idx): tuple(i) for idx, i in enumerate(attachments)
}
else:
files = {
'file': tuple(attachments[0]),
}
r = self.http(
Routes.CHANNELS_MESSAGES_CREATE,
dict(channel=channel),
data={'payload_json': json.dumps(payload)},
files=files
)
else:
r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload)
return Message.create(self.client, r.json())
def channels_messages_modify(self, channel, message, content, embed=None):
payload = {
'content': content,
}
def channels_messages_modify(self, channel, message, content=None, embed=None, sanitize=False):
payload = {}
if content:
if sanitize:
content = S(content)
payload['content'] = content
if embed:
payload['embed'] = embed.to_dict()
@ -285,6 +314,10 @@ class APIClient(LoggingClass):
def guilds_roles_delete(self, guild, role):
self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role))
def guilds_invites_list(self, guild):
r = self.http(Routes.GUILDS_INVITES_LIST, dict(guild=guild))
return Invite.create_map(self.client, r.json())
def guilds_webhooks_list(self, guild):
r = self.http(Routes.GUILDS_WEBHOOKS_LIST, dict(guild=guild))
return Webhook.create_map(self.client, r.json())
@ -295,11 +328,11 @@ class APIClient(LoggingClass):
def guilds_emojis_create(self, guild, **kwargs):
r = self.http(Routes.GUILDS_EMOJIS_CREATE, dict(guild=guild), json=kwargs)
return GuildEmoji.create(self.client, r.json())
return GuildEmoji.create(self.client, r.json(), guild_id=guild)
def guilds_emojis_modify(self, guild, emoji, **kwargs):
r = self.http(Routes.GUILDS_EMOJIS_MODIFY, dict(guild=guild, emoji=emoji), json=kwargs)
return GuildEmoji.create(self.client, r.json())
return GuildEmoji.create(self.client, r.json(), guild_id=guild)
def guilds_emojis_delete(self, guild, emoji):
self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji))
@ -311,6 +344,15 @@ class APIClient(LoggingClass):
r = self.http(Routes.USERS_ME_PATCH, json=payload)
return User.create(self.client, r.json())
def users_me_guilds_delete(self, guild):
self.http(Routes.USERS_ME_GUILDS_DELETE, dict(guild=guild))
def users_me_dms_create(self, recipient_id):
r = self.http(Routes.USERS_ME_DMS_CREATE, json={
'recipient_id': recipient_id,
})
return Channel.create(self.client, r.json())
def invites_get(self, invite):
r = self.http(Routes.INVITES_GET, dict(invite=invite))
return Invite.create(self.client, r.json())

8
disco/api/http.py

@ -108,7 +108,7 @@ class Routes(object):
USERS_ME_GET = (HTTPMethod.GET, USERS + '/@me')
USERS_ME_PATCH = (HTTPMethod.PATCH, USERS + '/@me')
USERS_ME_GUILDS_LIST = (HTTPMethod.GET, USERS + '/@me/guilds')
USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, USERS + '/@me/guilds/{guild}')
USERS_ME_GUILDS_DELETE = (HTTPMethod.DELETE, USERS + '/@me/guilds/{guild}')
USERS_ME_DMS_LIST = (HTTPMethod.GET, USERS + '/@me/channels')
USERS_ME_DMS_CREATE = (HTTPMethod.POST, USERS + '/@me/channels')
USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, USERS + '/@me/connections')
@ -176,7 +176,7 @@ class HTTPClient(LoggingClass):
A simple HTTP client which wraps the requests library, adding support for
Discords rate-limit headers, authorization, and request/response validation.
"""
BASE_URL = 'https://discordapp.com/api/v6'
BASE_URL = 'https://discordapp.com/api/v7'
MAX_RETRIES = 5
def __init__(self, token):
@ -189,13 +189,15 @@ class HTTPClient(LoggingClass):
self.limiter = RateLimiter()
self.headers = {
'Authorization': 'Bot ' + token,
'User-Agent': 'DiscordBot (https://github.com/b1naryth1ef/disco {}) Python/{} requests/{}'.format(
disco_version,
py_version,
requests_version),
}
if token:
self.headers['Authorization'] = 'Bot ' + token
def __call__(self, route, args=None, **kwargs):
return self.call(route, args, **kwargs)

57
disco/bot/bot.py

@ -65,6 +65,7 @@ class BotConfig(Config):
The directory plugin configuration is located within.
"""
levels = {}
plugins = []
plugin_config = {}
commands_enabled = True
@ -81,12 +82,13 @@ class BotConfig(Config):
commands_group_abbrev = True
plugin_config_provider = None
plugin_config_format = 'yaml'
plugin_config_format = 'json'
plugin_config_dir = 'config'
storage_enabled = True
storage_provider = 'memory'
storage_config = {}
storage_fsync = True
storage_serializer = 'json'
storage_path = 'storage.json'
class Bot(LoggingClass):
@ -195,35 +197,47 @@ class Bot(LoggingClass):
Called when a plugin is loaded/unloaded to recompute internal state.
"""
if self.config.commands_group_abbrev:
self.compute_group_abbrev()
groups = set(command.group for command in self.commands if command.group)
self.group_abbrev = self.compute_group_abbrev(groups)
self.compute_command_matches_re()
def compute_group_abbrev(self):
def compute_group_abbrev(self, groups):
"""
Computes all possible abbreviations for a command grouping.
"""
self.group_abbrev = {}
groups = set(command.group for command in self.commands if command.group)
# For the first pass, we just want to compute each groups possible
# abbreviations that don't conflict with eachother.
possible = {}
for group in groups:
grp = group
while grp:
# If the group already exists, means someone else thought they
# could use it so we need yank it from them (and not use it)
if grp in list(six.itervalues(self.group_abbrev)):
self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp}
for index in range(len(group)):
current = group[:index]
if current in possible:
possible[current] = None
else:
self.group_abbrev[group] = grp
possible[current] = group
grp = grp[:-1]
# Now, we want to compute the actual shortest abbreivation out of the
# possible ones
result = {}
for abbrev, group in six.iteritems(possible):
if not group:
continue
if group in result:
if len(abbrev) < len(result[group]):
result[group] = abbrev
else:
result[group] = abbrev
return result
def compute_command_matches_re(self):
"""
Computes a single regex which matches all possible command combinations.
"""
commands = list(self.commands)
re_str = '|'.join(command.regex for command in commands)
re_str = '|'.join(command.regex(grouped=False) for command in commands)
if re_str:
self.command_matches_re = re.compile(re_str, re.I)
else:
@ -267,7 +281,10 @@ class Bot(LoggingClass):
if msg.guild:
member = msg.guild.get_member(self.client.state.me)
if member:
# If nickname is set, filter both the normal and nick mentions
if member.nick:
content = content.replace(member.mention, '', 1)
content = content.replace(member.user.mention, '', 1)
else:
content = content.replace(self.client.state.me.mention, '', 1)
elif mention_everyone:
@ -355,10 +372,10 @@ class Bot(LoggingClass):
if event.message.author.id == self.client.state.me.id:
return
if self.config.commands_allow_edit:
self.last_message_cache[event.message.channel_id] = (event.message, False)
result = self.handle_message(event.message)
self.handle_message(event.message)
if self.config.commands_allow_edit:
self.last_message_cache[event.message.channel_id] = (event.message, result)
def on_message_update(self, event):
if self.config.commands_allow_edit:

26
disco/bot/command.py

@ -6,6 +6,7 @@ from disco.bot.parser import ArgumentSet, ArgumentError
from disco.util.functional import cached_property
ARGS_REGEX = '(?: ((?:\n|.)*)$|$)'
ARGS_UNGROUPED_REGEX = '(?: (?:\n|.)*$|$)'
USER_MENTION_RE = re.compile('<@!?([0-9]+)>')
ROLE_MENTION_RE = re.compile('<@&([0-9]+)>')
@ -44,11 +45,11 @@ class CommandEvent(object):
self.command = command
self.msg = msg
self.match = match
self.name = self.match.group(0)
self.name = self.match.group(1).strip()
self.args = []
if self.match.group(1):
self.args = [i for i in self.match.group(1).strip().split(' ') if i]
if self.match.group(2):
self.args = [i for i in self.match.group(2).strip().split(' ') if i]
@property
def codeblock(self):
@ -140,6 +141,10 @@ class Command(object):
self.update(*args, **kwargs)
@property
def name(self):
return self.triggers[0]
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
@ -218,10 +223,9 @@ class Command(object):
"""
A compiled version of this command's regex.
"""
return re.compile(self.regex, re.I)
return re.compile(self.regex(), re.I)
@property
def regex(self):
def regex(self, grouped=True):
"""
The regex string that defines/triggers this command.
"""
@ -231,10 +235,13 @@ class Command(object):
group = ''
if self.group:
if self.group in self.plugin.bot.group_abbrev:
group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group))
group = '(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group))
else:
group = self.group + ' '
return '^{}(?:{})'.format(group, '|'.join(self.triggers)) + ARGS_REGEX
return ('^{}({})' if grouped else '^{}(?:{})').format(
group,
'|'.join(self.triggers)
) + (ARGS_REGEX if grouped else ARGS_UNGROUPED_REGEX)
def execute(self, event):
"""
@ -247,9 +254,10 @@ class Command(object):
Whether this command was successful
"""
if len(event.args) < self.args.required_length:
raise CommandError('{} requires {} arguments (passed {})'.format(
raise CommandError(u'Command {} requires {} arguments (`{}`) passed {}'.format(
event.name,
self.args.required_length,
self.raw_args,
len(event.args)
))

6
disco/bot/parser.py

@ -15,12 +15,6 @@ TYPE_MAP = {
'snowflake': lambda ctx, data: int(data),
}
try:
import dateparser
TYPE_MAP['duration'] = lambda ctx, data: dateparser.parse(data, settings={'TIMEZONE': 'UTC'})
except ImportError:
pass
def to_bool(ctx, data):
if data in BOOL_OPTS:

11
disco/bot/plugin.py

@ -209,15 +209,22 @@ class Plugin(LoggingClass, PluginDeco):
def handle_exception(self, greenlet, event):
pass
def wait_for_event(self, event_name, **kwargs):
def wait_for_event(self, event_name, conditional=None, **kwargs):
result = AsyncResult()
listener = None
def _event_callback(event):
for k, v in kwargs.items():
if getattr(event, k) != v:
obj = event
for inst in k.split('__'):
obj = getattr(obj, inst)
if obj != v:
break
else:
if conditional and not conditional(event):
return
listener.remove()
return result.set(event)

15
disco/bot/providers/__init__.py

@ -1,15 +0,0 @@
import inspect
import importlib
from .base import BaseProvider
def load_provider(name):
try:
mod = importlib.import_module('disco.bot.providers.' + name)
except ImportError:
mod = importlib.import_module(name)
for entry in filter(inspect.isclass, map(lambda i: getattr(mod, i), dir(mod))):
if issubclass(entry, BaseProvider) and entry != BaseProvider:
return entry

134
disco/bot/providers/base.py

@ -1,134 +0,0 @@
import six
import pickle
from six.moves import map, UserDict
ROOT_SENTINEL = u'\u200B'
SEP_SENTINEL = u'\u200D'
OBJ_SENTINEL = u'\u200C'
CAST_SENTINEL = u'\u24EA'
def join_key(*args):
nargs = []
for arg in args:
if not isinstance(arg, six.string_types):
arg = CAST_SENTINEL + pickle.dumps(arg)
nargs.append(arg)
return SEP_SENTINEL.join(nargs)
def true_key(key):
key = key.rsplit(SEP_SENTINEL, 1)[-1]
if key.startswith(CAST_SENTINEL):
return pickle.loads(key)
return key
class BaseProvider(object):
def __init__(self, config):
self.config = config
self.data = {}
def exists(self, key):
return key in self.data
def keys(self, other):
count = other.count(SEP_SENTINEL) + 1
for key in self.data.keys():
if key.startswith(other) and key.count(SEP_SENTINEL) == count:
yield key
def get_many(self, keys):
for key in keys:
yield key, self.get(key)
def get(self, key):
return self.data[key]
def set(self, key, value):
self.data[key] = value
def delete(self, key):
del self.data[key]
def load(self):
pass
def save(self):
pass
def root(self):
return StorageDict(self)
class StorageDict(UserDict):
def __init__(self, parent_or_provider, key=None):
if isinstance(parent_or_provider, BaseProvider):
self.provider = parent_or_provider
self.parent = None
else:
self.parent = parent_or_provider
self.provider = self.parent.provider
self._key = key or ROOT_SENTINEL
def keys(self):
return map(true_key, self.provider.keys(self.key))
def values(self):
for key in self.keys():
yield self.provider.get(key)
def items(self):
for key in self.keys():
yield (true_key(key), self.provider.get(key))
def ensure(self, key, typ=dict):
if key not in self:
self[key] = typ()
return self[key]
def update(self, obj):
for k, v in six.iteritems(obj):
self[k] = v
@property
def data(self):
obj = {}
for raw, value in self.provider.get_many(self.provider.keys(self.key)):
key = true_key(raw)
if value == OBJ_SENTINEL:
value = self.__class__(self, key=key).data
obj[key] = value
return obj
@property
def key(self):
if self.parent is not None:
return join_key(self.parent.key, self._key)
return self._key
def __setitem__(self, key, value):
if isinstance(value, dict):
obj = self.__class__(self, key)
obj.update(value)
value = OBJ_SENTINEL
self.provider.set(join_key(self.key, key), value)
def __getitem__(self, key):
res = self.provider.get(join_key(self.key, key))
if res == OBJ_SENTINEL:
return self.__class__(self, key)
return res
def __delitem__(self, key):
return self.provider.delete(join_key(self.key, key))
def __contains__(self, key):
return self.provider.exists(join_key(self.key, key))

54
disco/bot/providers/disk.py

@ -1,54 +0,0 @@
import os
import gevent
from disco.util.serializer import Serializer
from .base import BaseProvider
class DiskProvider(BaseProvider):
def __init__(self, config):
super(DiskProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.path = config.get('path', 'storage') + '.' + self.format
self.fsync = config.get('fsync', False)
self.fsync_changes = config.get('fsync_changes', 1)
self.autosave_task = None
self.change_count = 0
def autosave_loop(self, interval):
while True:
gevent.sleep(interval)
self.save()
def _on_change(self):
if self.fsync:
self.change_count += 1
if self.change_count >= self.fsync_changes:
self.save()
self.change_count = 0
def load(self):
if not os.path.exists(self.path):
return
if self.config.get('autosave', True):
self.autosave_task = gevent.spawn(
self.autosave_loop,
self.config.get('autosave_interval', 120))
with open(self.path, 'r') as f:
self.data = Serializer.loads(self.format, f.read())
def save(self):
with open(self.path, 'w') as f:
f.write(Serializer.dumps(self.format, self.data))
def set(self, key, value):
super(DiskProvider, self).set(key, value)
self._on_change()
def delete(self, key):
super(DiskProvider, self).delete(key)
self._on_change()

5
disco/bot/providers/memory.py

@ -1,5 +0,0 @@
from .base import BaseProvider
class MemoryProvider(BaseProvider):
pass

48
disco/bot/providers/redis.py

@ -1,48 +0,0 @@
from __future__ import absolute_import
import redis
from itertools import izip
from disco.util.serializer import Serializer
from .base import BaseProvider, SEP_SENTINEL
class RedisProvider(BaseProvider):
def __init__(self, config):
super(RedisProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.conn = None
def load(self):
self.conn = redis.Redis(
host=self.config.get('host', 'localhost'),
port=self.config.get('port', 6379),
db=self.config.get('db', 0))
def exists(self, key):
return self.conn.exists(key)
def keys(self, other):
count = other.count(SEP_SENTINEL) + 1
for key in self.conn.scan_iter(u'{}*'.format(other)):
key = key.decode('utf-8')
if key.count(SEP_SENTINEL) == count:
yield key
def get_many(self, keys):
keys = list(keys)
if not len(keys):
raise StopIteration
for key, value in izip(keys, self.conn.mget(keys)):
yield (key, Serializer.loads(self.format, value))
def get(self, key):
return Serializer.loads(self.format, self.conn.get(key))
def set(self, key, value):
self.conn.set(key, Serializer.dumps(self.format, value))
def delete(self, key):
self.conn.delete(key)

52
disco/bot/providers/rocksdb.py

@ -1,52 +0,0 @@
from __future__ import absolute_import
import six
import rocksdb
from itertools import izip
from six.moves import map
from disco.util.serializer import Serializer
from .base import BaseProvider, SEP_SENTINEL
class RocksDBProvider(BaseProvider):
def __init__(self, config):
super(RocksDBProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.path = config.get('path', 'storage.db')
self.db = None
@staticmethod
def k(k):
return bytes(k) if six.PY3 else str(k.encode('utf-8'))
def load(self):
self.db = rocksdb.DB(self.path, rocksdb.Options(create_if_missing=True))
def exists(self, key):
return self.db.get(self.k(key)) is not None
# TODO prefix extractor
def keys(self, other):
count = other.count(SEP_SENTINEL) + 1
it = self.db.iterkeys()
it.seek_to_first()
for key in it:
key = key.decode('utf-8')
if key.startswith(other) and key.count(SEP_SENTINEL) == count:
yield key
def get_many(self, keys):
for key, value in izip(keys, self.db.multi_get(list(map(self.k, keys)))):
yield (key, Serializer.loads(self.format, value.decode('utf-8')))
def get(self, key):
return Serializer.loads(self.format, self.db.get(self.k(key)).decode('utf-8'))
def set(self, key, value):
self.db.put(self.k(key), Serializer.dumps(self.format, value))
def delete(self, key):
self.db.delete(self.k(key))

97
disco/bot/storage.py

@ -1,26 +1,87 @@
from .providers import load_provider
import os
from six.moves import UserDict
class Storage(object):
def __init__(self, ctx, config):
from disco.util.hashmap import HashMap
from disco.util.serializer import Serializer
class StorageHashMap(HashMap):
def __init__(self, data):
self.data = data
class ContextAwareProxy(UserDict):
def __init__(self, ctx):
self.ctx = ctx
self.config = config
self.provider = load_provider(config.provider)(config.config)
self.provider.load()
self.root = self.provider.root()
@property
def plugin(self):
return self.root.ensure('plugins').ensure(self.ctx['plugin'].name)
def data(self):
return self.ctx()
@property
def guild(self):
return self.plugin.ensure('guilds').ensure(self.ctx['guild'].id)
@property
def channel(self):
return self.plugin.ensure('channels').ensure(self.ctx['channel'].id)
class StorageDict(UserDict):
def __init__(self, parent, data):
self._parent = parent
self.data = data
@property
def user(self):
return self.plugin.ensure('users').ensure(self.ctx['user'].id)
def update(self, other):
self.data.update(other)
self._parent._update()
def __setitem__(self, key, value):
self.data[key] = value
self._parent._update()
def __delitem__(self, key):
del self.data[key]
self._parent._update()
class Storage(object):
def __init__(self, ctx, config):
self._ctx = ctx
self._path = config.path
self._serializer = config.serializer
self._fsync = config.fsync
self._data = {}
if os.path.exists(self._path):
with open(self._path, 'r') as f:
self._data = Serializer.loads(self._serializer, f.read())
def __getitem__(self, key):
if key not in self._data:
self._data[key] = {}
return StorageHashMap(StorageDict(self, self._data[key]))
def _update(self):
if self._fsync:
self.save()
def save(self):
if not self._path:
return
with open(self._path, 'w') as f:
f.write(Serializer.dumps(self._serializer, self._data))
def guild(self, key):
return ContextAwareProxy(
lambda: self['_g{}:{}'.format(self._ctx['guild'].id, key)]
)
def channel(self, key):
return ContextAwareProxy(
lambda: self['_c{}:{}'.format(self._ctx['channel'].id, key)]
)
def plugin(self, key):
return ContextAwareProxy(
lambda: self['_p{}:{}'.format(self._ctx['plugin'].name, key)]
)
def user(self, key):
return ContextAwareProxy(
lambda: self['_u{}:{}'.format(self._ctx['user'].id, key)]
)

44
disco/client.py

@ -15,13 +15,13 @@ from disco.util.backdoor import DiscoBackdoorServer
class ClientConfig(Config):
"""
Configuration for the :class:`Client`.
Configuration for the `Client`.
Attributes
----------
token : str
Discord authentication token, can be validated using the
:func:`disco.util.token.is_valid_token` function.
`disco.util.token.is_valid_token` function.
shard_id : int
The shard ID for the current client instance.
shard_count : int
@ -53,32 +53,32 @@ class Client(LoggingClass):
"""
Class representing the base entry point that should be used in almost all
implementation cases. This class wraps the functionality of both the REST API
(:class:`disco.api.client.APIClient`) and the realtime gateway API
(:class:`disco.gateway.client.GatewayClient`).
(`disco.api.client.APIClient`) and the realtime gateway API
(`disco.gateway.client.GatewayClient`).
Parameters
----------
config : :class:`ClientConfig`
config : `ClientConfig`
Configuration for this client instance.
Attributes
----------
config : :class:`ClientConfig`
config : `ClientConfig`
The runtime configuration for this client.
events : :class:`Emitter`
events : `Emitter`
An emitter which emits Gateway events.
packets : :class:`Emitter`
packets : `Emitter`
An emitter which emits Gateway packets.
state : :class:`State`
state : `State`
The state tracking object.
api : :class:`APIClient`
api : `APIClient`
The API client.
gw : :class:`GatewayClient`
gw : `GatewayClient`
The gateway client.
manhole_locals : dict
Dictionary of local variables for each manhole connection. This can be
modified to add/modify local variables.
manhole : Optional[:class:`BackdoorServer`]
manhole : Optional[`BackdoorServer`]
Gevent backdoor server (if the manhole is enabled).
"""
def __init__(self, config):
@ -105,7 +105,21 @@ class Client(LoggingClass):
localf=lambda: self.manhole_locals)
self.manhole.start()
def update_presence(self, game=None, status=None, afk=False, since=0.0):
def update_presence(self, status, game=None, afk=False, since=0.0):
"""
Updates the current clients presence.
Params
------
status : `user.Status`
The clients current status.
game : `user.Game`
If passed, the game object to set for the users presence.
afk : bool
Whether the client is currently afk.
since : float
How long the client has been afk for (in seconds).
"""
if game and not isinstance(game, Game):
raise TypeError('Game must be a Game model')
@ -126,12 +140,12 @@ class Client(LoggingClass):
def run(self):
"""
Run the client (e.g. the :class:`GatewayClient`) in a new greenlet.
Run the client (e.g. the `GatewayClient`) in a new greenlet.
"""
return gevent.spawn(self.gw.run)
def run_forever(self):
"""
Run the client (e.g. the :class:`GatewayClient`) in the current greenlet.
Run the client (e.g. the `GatewayClient`) in the current greenlet.
"""
return self.gw.run()

10
disco/gateway/client.py

@ -53,6 +53,8 @@ class GatewayClient(LoggingClass):
self.session_id = None
self.reconnects = 0
self.shutting_down = False
self.replaying = False
self.replayed_events = 0
# Cached gateway URL
self._cached_gateway_url = None
@ -81,6 +83,8 @@ class GatewayClient(LoggingClass):
obj = GatewayEvent.from_dispatch(self.client, packet)
self.log.debug('Dispatching %s', obj.__class__.__name__)
self.client.events.emit(obj.__class__.__name__, obj)
if self.replaying:
self.replayed_events += 1
def handle_heartbeat(self, _):
self._send(OPCode.HEARTBEAT, self.seq)
@ -105,8 +109,9 @@ class GatewayClient(LoggingClass):
self.reconnects = 0
def on_resumed(self, _):
self.log.info('Recieved RESUMED')
self.log.info('RESUME completed, replayed %s events', self.replayed_events)
self.reconnects = 0
self.replaying = False
def connect_and_run(self, gateway_url=None):
if not gateway_url:
@ -154,6 +159,7 @@ class GatewayClient(LoggingClass):
def on_open(self):
if self.seq and self.session_id:
self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
self.replaying = True
self.send(OPCode.RESUME, {
'token': self.client.config.token,
'session_id': self.session_id,
@ -188,6 +194,8 @@ class GatewayClient(LoggingClass):
self.log.info('WS Closed: shutting down')
return
self.replaying = False
# Track reconnect attempts
self.reconnects += 1
self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects)

11
disco/gateway/events.py

@ -49,6 +49,8 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)):
"""
Create this GatewayEvent class from data and the client.
"""
cls.raw_data = obj
# If this event is wrapping a model, pull its fields
if hasattr(cls, '_wraps_model'):
alias, model = cls._wraps_model
@ -151,6 +153,7 @@ class GuildCreate(GatewayEvent):
and if None, this is a normal guild join event.
"""
unavailable = Field(bool)
presences = ListField(Presence)
@property
def created(self):
@ -607,6 +610,14 @@ class MessageReactionAdd(GatewayEvent):
user_id = Field(snowflake)
emoji = Field(MessageReactionEmoji)
def delete(self):
self.client.api.channels_messages_reactions_delete(
self.channel_id,
self.message_id,
self.emoji.to_string() if self.emoji.id else self.emoji.name,
self.user_id
)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)

68
disco/state.py

@ -42,10 +42,10 @@ class StateConfig(Config):
find they do not need and may be experiencing memory pressure can disable
this feature entirely using this attribute.
track_messages_size : int
The size of the deque for each channel. Using this you can calculate the
total number of possible :class:`StackMessage` objects kept in memory,
using: `total_mesages_size * total_channels`. This can be tweaked based
on usage to help prevent memory pressure.
The size of the messages deque for each channel. This value can be used
to calculate the total number of possible `StackMessage` objects kept in
memory, simply: `total_messages_size * total_channels`. This value can
be tweaked based on usage and to help prevent memory pressure.
sync_guild_members : bool
If true, guilds will be automatically synced when they are initially loaded
or joined. Generally this setting is OK for smaller bots, however bots in over
@ -60,31 +60,31 @@ class StateConfig(Config):
class State(object):
"""
The State class is used to track global state based on events emitted from
the :class:`GatewayClient`. State tracking is a core component of the Disco
client, providing the mechanism for most of the higher-level utility functions.
the `GatewayClient`. State tracking is a core component of the Disco client,
providing the mechanism for most of the higher-level utility functions.
Attributes
----------
EVENTS : list(str)
A list of all events the State object binds to
client : :class:`disco.client.Client`
client : `disco.client.Client`
The Client instance this state is attached to
config : :class:`StateConfig`
config : `StateConfig`
The configuration for this state instance
me : :class:`disco.types.user.User`
me : `User`
The currently logged in user
dms : dict(snowflake, :class:`disco.types.channel.Channel`)
dms : dict(snowflake, `Channel`)
Mapping of all known DM Channels
guilds : dict(snowflake, :class:`disco.types.guild.Guild`)
guilds : dict(snowflake, `Guild`)
Mapping of all known/loaded Guilds
channels : dict(snowflake, :class:`disco.types.channel.Channel`)
channels : dict(snowflake, `Channel`)
Weak mapping of all known/loaded Channels
users : dict(snowflake, :class:`disco.types.user.User`)
users : dict(snowflake, `User`)
Weak mapping of all known/loaded Users
voice_states : dict(str, :class:`disco.types.voice.VoiceState`)
voice_states : dict(str, `VoiceState`)
Weak mapping of all known/active Voice States
messages : Optional[dict(snowflake, :class:`deque`)]
Mapping of channel ids to deques containing :class:`StackMessage` objects
messages : Optional[dict(snowflake, deque)]
Mapping of channel ids to deques containing `StackMessage` objects
"""
EVENTS = [
'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove',
@ -184,8 +184,12 @@ class State(object):
self.channels.update(event.guild.channels)
for member in six.itervalues(event.guild.members):
if member.user.id not in self.users:
self.users[member.user.id] = member.user
for presence in event.presences:
self.users[presence.user.id].presence = presence
for voice_state in six.itervalues(event.guild.voice_states):
self.voice_states[voice_state.session_id] = voice_state
@ -282,6 +286,7 @@ class State(object):
for member in event.members:
member.guild_id = guild.id
guild.members[member.id] = member
if member.id not in self.users:
self.users[member.id] = member.user
def on_guild_role_create(self, event):
@ -309,18 +314,33 @@ class State(object):
if event.guild_id not in self.guilds:
return
for emoji in event.emojis:
emoji.guild_id = event.guild_id
self.guilds[event.guild_id].emojis = HashMap({i.id: i for i in event.emojis})
def on_presence_update(self, event):
if event.user.id in self.users:
self.users[event.user.id].update(event.presence.user)
self.users[event.user.id].presence = event.presence
event.presence.user = self.users[event.user.id]
if event.guild_id not in self.guilds:
# TODO: this is recursive, we hackfix in model, but its still lame ATM
user = event.presence.user
user.presence = event.presence
# if we have the user tracked locally, we can just use the presence
# update to update both their presence and the cached user object.
if user.id in self.users:
self.users[user.id].update(user)
else:
# Otherwise this user does not exist in our local cache, so we can
# use this opportunity to add them. They will quickly fall out of
# scope and be deleted if they aren't used below
self.users[user.id] = user
# Some updates come with a guild_id and roles the user is in, we should
# use this to update the guild member, but only if we have the guild
# cached.
if event.roles is UNSET or event.guild_id not in self.guilds:
return
if event.user.id not in self.guilds[event.guild_id].members:
if user.id not in self.guilds[event.guild_id].members:
return
self.guilds[event.guild_id].members[event.user.id].user.update(event.user)
self.guilds[event.guild_id].members[user.id].roles = event.roles

20
disco/types/base.py

@ -6,8 +6,9 @@ import functools
from holster.enum import BaseEnumMeta, EnumAttr
from datetime import datetime as real_datetime
from disco.util.functional import CachedSlotProperty
from disco.util.chains import Chainable
from disco.util.hashmap import HashMap
from disco.util.functional import CachedSlotProperty
DATETIME_FORMATS = [
'%Y-%m-%dT%H:%M:%S.%f',
@ -25,6 +26,9 @@ class Unset(object):
def __nonzero__(self):
return False
def __bool__(self):
return False
UNSET = Unset()
@ -270,15 +274,7 @@ class ModelMeta(type):
return super(ModelMeta, mcs).__new__(mcs, name, parents, dct)
class AsyncChainable(object):
__slots__ = []
def after(self, delay):
gevent.sleep(delay)
return self
class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
class Model(six.with_metaclass(ModelMeta, Chainable)):
__slots__ = ['client']
def __init__(self, *args, **kwargs):
@ -294,6 +290,10 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
self.load(obj)
self.validate()
def after(self, delay):
gevent.sleep(delay)
return self
def validate(self):
pass

15
disco/types/channel.py

@ -1,3 +1,4 @@
import re
import six
from six.moves import map
@ -11,6 +12,9 @@ from disco.types.permissions import Permissions, Permissible, PermissionValue
from disco.voice.client import VoiceClient
NSFW_RE = re.compile('^nsfw(-|$)')
ChannelType = Enum(
GUILD_TEXT=0,
DM=1,
@ -179,6 +183,13 @@ class Channel(SlottedModel, Permissible):
"""
return self.type in (ChannelType.DM, ChannelType.GROUP_DM)
@property
def is_nsfw(self):
"""
Whether this channel is an NSFW channel.
"""
return self.type == ChannelType.GUILD_TEXT and NSFW_RE.match(self.name)
@property
def is_voice(self):
"""
@ -244,7 +255,7 @@ class Channel(SlottedModel, Permissible):
def create_webhook(self, name=None, avatar=None):
return self.client.api.channels_webhooks_create(self.id, name, avatar)
def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None):
def send_message(self, *args, **kwargs):
"""
Send a message in this channel.
@ -262,7 +273,7 @@ class Channel(SlottedModel, Permissible):
:class:`disco.types.message.Message`
The created message.
"""
return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed)
return self.client.api.channels_messages_create(self.id, *args, **kwargs)
def connect(self, *args, **kwargs):
"""

37
disco/types/guild.py

@ -9,7 +9,7 @@ from disco.util.functional import cached_property
from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum, datetime
)
from disco.types.user import User, Presence
from disco.types.user import User
from disco.types.voice import VoiceState
from disco.types.channel import Channel
from disco.types.message import Emoji
@ -23,6 +23,17 @@ VerificationLevel = Enum(
HIGH=3,
)
ExplicitContentFilterLevel = Enum(
NONE=0,
WITHOUT_ROLES=1,
ALL=2
)
DefaultMessageNotificationsLevel = Enum(
ALL_MESSAGES=0,
ONLY_MENTIONS=1,
)
class GuildEmoji(Emoji):
"""
@ -51,6 +62,12 @@ class GuildEmoji(Emoji):
def __str__(self):
return u'<:{}:{}>'.format(self.name, self.id)
def update(self, **kwargs):
return self.client.api.guilds_emojis_modify(self.guild_id, self.id, **kwargs)
def delete(self):
return self.client.api.guilds_emojis_delete(self.guild_id, self.id)
@property
def url(self):
return 'https://discordapp.com/api/emojis/{}.png'.format(self.id)
@ -102,7 +119,7 @@ class Role(SlottedModel):
@property
def mention(self):
return '<@{}>'.format(self.id)
return '<@&{}>'.format(self.id)
@cached_property
def guild(self):
@ -289,6 +306,8 @@ class Guild(SlottedModel, Permissible):
afk_timeout = Field(int)
embed_enabled = Field(bool)
verification_level = Field(enum(VerificationLevel))
explicit_content_filter = Field(enum(ExplicitContentFilterLevel))
default_message_notifications = Field(enum(DefaultMessageNotificationsLevel))
mfa_level = Field(int)
features = ListField(str)
members = AutoDictField(GuildMember, 'id')
@ -297,7 +316,6 @@ class Guild(SlottedModel, Permissible):
emojis = AutoDictField(GuildEmoji, 'id')
voice_states = AutoDictField(VoiceState, 'session_id')
member_count = Field(int)
presences = ListField(Presence)
synced = Field(bool, default=False)
@ -310,6 +328,10 @@ class Guild(SlottedModel, Permissible):
self.attach(six.itervalues(self.emojis), {'guild_id': self.id})
self.attach(six.itervalues(self.voice_states), {'guild_id': self.id})
@cached_property
def owner(self):
return self.members.get(self.owner_id)
def get_permissions(self, member):
"""
Get the permissions a user has in this guild.
@ -417,3 +439,12 @@ class Guild(SlottedModel, Permissible):
def create_channel(self, *args, **kwargs):
return self.client.api.guilds_channels_create(self.id, *args, **kwargs)
def leave(self):
return self.client.api.users_me_guilds_delete(self.id)
def get_invites(self):
return self.client.api.guilds_invites_list(self.id)
def get_emojis(self):
return self.client.api.guilds_emojis_list(self.id)

2
disco/types/invite.py

@ -40,7 +40,7 @@ class Invite(SlottedModel):
created_at = Field(datetime)
@classmethod
def create(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False):
def create_for_channel(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False):
return channel.client.api.channels_invites_create(
channel.id,
max_age=max_age,

7
disco/types/message.py

@ -1,5 +1,6 @@
import re
import six
import warnings
import functools
import unicodedata
@ -315,6 +316,12 @@ class Message(SlottedModel):
)
def create_reaction(self, emoji):
warnings.warn(
'Message.create_reaction will be deprecated soon, use Message.add_reaction',
DeprecationWarning)
return self.add_reaction(emoji)
def add_reaction(self, emoji):
if isinstance(emoji, Emoji):
emoji = emoji.to_string()
self.client.api.channels_messages_reactions_create(

4
disco/types/permissions.py

@ -40,6 +40,10 @@ class PermissionValue(object):
self.value = value
def can(self, *perms):
# Administrator permission overwrites all others
if self.administrator:
return True
for perm in perms:
if isinstance(perm, EnumAttr):
perm = perm.value

3
disco/types/user.py

@ -45,6 +45,9 @@ class User(SlottedModel, with_equality('id'), with_hash('id')):
def mention(self):
return '<@{}>'.format(self.id)
def open_dm(self):
return self.client.api.users_me_dms_create(self.id)
def __str__(self):
return u'{}#{}'.format(self.username, str(self.discriminator).zfill(4))

23
disco/types/webhook.py

@ -1,8 +1,13 @@
import re
from disco.types.base import SlottedModel, Field, snowflake
from disco.types.user import User
from disco.util.functional import cached_property
WEBHOOK_URL_RE = re.compile(r'\/api\/webhooks\/(\d+)\/(.[^/]+)')
class Webhook(SlottedModel):
id = Field(snowflake)
guild_id = Field(snowflake)
@ -12,6 +17,19 @@ class Webhook(SlottedModel):
avatar = Field(str)
token = Field(str)
@classmethod
def execute_url(cls, url, **kwargs):
from disco.api.client import APIClient
results = WEBHOOK_URL_RE.findall(url)
if len(results) != 1:
return Exception('Invalid Webhook URL')
return cls(id=results[0][0], token=results[0][1]).execute(
client=APIClient(None),
**kwargs
)
@cached_property
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@ -32,10 +50,11 @@ class Webhook(SlottedModel):
else:
return self.client.api.webhooks_modify(self.id, name, avatar)
def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False):
def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False, client=None):
# TODO: support file stuff properly
client = client or self.client.api
return self.client.api.webhooks_token_execute(self.id, self.token, {
return client.webhooks_token_execute(self.id, self.token, {
'content': content,
'username': username,
'avatar_url': avatar_url,

71
disco/util/chains.py

@ -0,0 +1,71 @@
import gevent
"""
Object.chain -> creates a chain where each action happens after the last
pass_result = False -> whether the result of the last action is passed, or the original
Object.async_chain -> creates an async chain where each action happens at the same time
"""
class Chainable(object):
__slots__ = []
def chain(self, pass_result=True):
return Chain(self, pass_result=pass_result, async_=False)
def async_chain(self):
return Chain(self, pass_result=False, async_=True)
class Chain(object):
def __init__(self, obj, pass_result=True, async_=False):
self._obj = obj
self._pass_result = pass_result
self._async = async_
self._parts = []
@property
def obj(self):
if isinstance(self._obj, Chain):
return self._obj._next()
return self._obj
def __getattr__(self, item):
func = getattr(self.obj, item)
if not func or not callable(func):
return func
def _wrapped(*args, **kwargs):
inst = gevent.spawn(func, *args, **kwargs)
self._parts.append(inst)
# If async, just return instantly
if self._async:
return self
# Otherwise return a chain
return Chain(self)
return _wrapped
def _next(self):
res = self._parts[0].get()
if self._pass_result:
return res
return self
def then(self, func, *args, **kwargs):
inst = gevent.spawn(func, *args, **kwargs)
self._parts.append(inst)
if self._async:
return self
return Chain(self)
def first(self):
return self._obj
def get(self, timeout=None):
return gevent.wait(self._parts, timeout=timeout)
def wait(self, timeout=None):
gevent.joinall(self._parts, timeout=None)

9
disco/util/logging.py

@ -1,5 +1,6 @@
from __future__ import absolute_import
import warnings
import logging
@ -9,10 +10,18 @@ LEVEL_OVERRIDES = {
LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'
def setup_logging(**kwargs):
kwargs.setdefault('format', LOG_FORMAT)
# Setup warnings module correctly
warnings.simplefilter('always', DeprecationWarning)
logging.captureWarnings(True)
# Pass through our basic configuration
logging.basicConfig(**kwargs)
# Override some noisey loggers
for logger, level in LEVEL_OVERRIDES.items():
logging.getLogger(logger).setLevel(level)

32
disco/util/sanitize.py

@ -0,0 +1,32 @@
import re
# Zero width (non-rendering) space that can be used to escape mentions
ZERO_WIDTH_SPACE = u'\u200B'
# A grave-looking character that can be used to escape codeblocks
MODIFIER_GRAVE_ACCENT = u'\u02CB'
# Regex which matches all possible mention combinations, this may be over-zealous
# but its better safe than sorry.
MENTION_RE = re.compile('<?([@|#][!|&]?[0-9]+|@everyone|@here)>?')
def _re_sub_mention(mention):
mention = mention.group(1)
if '#' in mention:
return (u'#' + ZERO_WIDTH_SPACE).join(mention.split('#', 1))
elif '@' in mention:
return (u'@' + ZERO_WIDTH_SPACE).join(mention.split('@', 1))
else:
return mention
def S(text, escape_mentions=True, escape_codeblocks=False):
if escape_mentions:
text = MENTION_RE.sub(_re_sub_mention, text)
if escape_codeblocks:
text = text.replace('`', MODIFIER_GRAVE_ACCENT)
return text

9
disco/util/snowflake.py

@ -2,6 +2,7 @@ import six
from datetime import datetime
UNIX_EPOCH = datetime(1970, 1, 1)
DISCORD_EPOCH = 1420070400000
@ -20,6 +21,14 @@ def to_unix_ms(snowflake):
return (int(snowflake) >> 22) + DISCORD_EPOCH
def from_datetime(date):
return from_timestamp((date - UNIX_EPOCH).total_seconds())
def from_timestamp(ts):
return long(ts * 1000.0 - DISCORD_EPOCH) << 22
def to_snowflake(i):
if isinstance(i, six.integer_types):
return i

3
disco/voice/__init__.py

@ -0,0 +1,3 @@
from disco.voice.client import *
from disco.voice.player import *
from disco.voice.playable import *

116
disco/voice/client.py

@ -1,8 +1,15 @@
from __future__ import print_function
import gevent
import socket
import struct
import time
try:
import nacl.secret
except ImportError:
print('WARNING: nacl is not installed, voice support is disabled')
from holster.enum import Enum
from holster.emitter import Emitter
@ -22,11 +29,6 @@ VoiceState = Enum(
VOICE_CONNECTED=6,
)
# TODO:
# - player implementation
# - encryption
# - cleanup
class VoiceException(Exception):
def __init__(self, msg, client):
@ -38,12 +40,40 @@ class UDPVoiceClient(LoggingClass):
def __init__(self, vc):
super(UDPVoiceClient, self).__init__()
self.vc = vc
# The underlying UDP socket
self.conn = None
# Connection information
self.ip = None
self.port = None
self.run_task = None
self.connected = False
def send_frame(self, frame, sequence=None, timestamp=None):
# Convert the frame to a bytearray
frame = bytearray(frame)
# First, pack the header (TODO: reuse bytearray?)
header = bytearray(24)
header[0] = 0x80
header[1] = 0x78
struct.pack_into('>H', header, 2, sequence or self.vc.sequence)
struct.pack_into('>I', header, 4, timestamp or self.vc.timestamp)
struct.pack_into('>i', header, 8, self.vc.ssrc)
# Now encrypt the payload with the nonce as a header
raw = self.vc.secret_box.encrypt(bytes(frame), bytes(header)).ciphertext
# Send the header (sans nonce padding) plus the payload
self.send(header[:12] + raw)
# Increment our sequence counter
self.vc.sequence += 1
if self.vc.sequence >= 65535:
self.vc.sequence = 0
def run(self):
while True:
self.conn.recvfrom(4096)
@ -54,12 +84,15 @@ class UDPVoiceClient(LoggingClass):
def disconnect(self):
self.run_task.kill()
def connect(self, host, port, timeout=10):
def connect(self, host, port, timeout=10, addrinfo=None):
self.ip = socket.gethostbyname(host)
self.port = port
self.conn = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if addrinfo:
ip, port = addrinfo
else:
# Send discovery packet
packet = bytearray(70)
struct.pack_into('>I', packet, 0, self.vc.ssrc)
@ -86,30 +119,45 @@ class VoiceClient(LoggingClass):
def __init__(self, channel, encoder=None):
super(VoiceClient, self).__init__()
assert channel.is_voice, 'Cannot spawn a VoiceClient for a non-voice channel'
if not channel.is_voice:
raise ValueError('Cannot spawn a VoiceClient for a non-voice channel')
self.channel = channel
self.client = self.channel.client
self.encoder = encoder or JSONEncoder
# Bind to some WS packets
self.packets = Emitter(gevent.spawn)
self.packets.on(VoiceOPCode.READY, self.on_voice_ready)
self.packets.on(VoiceOPCode.SESSION_DESCRIPTION, self.on_voice_sdp)
# State
# State + state change emitter
self.state = VoiceState.DISCONNECTED
self.connected = gevent.event.Event()
self.state_emitter = Emitter(gevent.spawn)
# Connection metadata
self.token = None
self.endpoint = None
self.ssrc = None
self.port = None
self.secret_box = None
self.udp = None
# Voice data state
self.sequence = 0
self.timestamp = 0
self.update_listener = None
# Websocket connection
self.ws = None
self.heartbeat_task = None
def set_state(self, state):
prev_state = self.state
self.state = state
self.state_emitter.emit(state, prev_state)
def heartbeat(self, interval):
while True:
self.send(VoiceOPCode.HEARTBEAT, time.time() * 1000)
@ -128,7 +176,7 @@ class VoiceClient(LoggingClass):
}), self.encoder.OPCODE)
def on_voice_ready(self, data):
self.state = VoiceState.CONNECTING
self.set_state(VoiceState.CONNECTING)
self.ssrc = data['ssrc']
self.port = data['port']
@ -146,18 +194,20 @@ class VoiceClient(LoggingClass):
'data': {
'port': port,
'address': ip,
'mode': 'plain'
'mode': 'xsalsa20_poly1305'
}
})
def on_voice_sdp(self, _):
def on_voice_sdp(self, sdp):
# Create a secret box for encryption/decryption
self.secret_box = nacl.secret.SecretBox(bytes(bytearray(sdp['secret_key'])))
# Toggle speaking state so clients learn of our SSRC
self.set_speaking(True)
self.set_speaking(False)
gevent.sleep(0.25)
self.state = VoiceState.CONNECTED
self.connected.set()
self.set_state(VoiceState.CONNECTED)
def on_voice_server_update(self, data):
if self.channel.guild_id != data.guild_id or not data.token:
@ -167,30 +217,28 @@ class VoiceClient(LoggingClass):
return
self.token = data.token
self.state = VoiceState.AUTHENTICATING
self.set_state(VoiceState.AUTHENTICATING)
self.endpoint = data.endpoint.split(':', 1)[0]
self.ws = Websocket(
'wss://' + self.endpoint,
on_message=self.on_message,
on_error=self.on_error,
on_open=self.on_open,
on_close=self.on_close,
)
self.ws = Websocket('wss://' + self.endpoint)
self.ws.emitter.on('on_open', self.on_open)
self.ws.emitter.on('on_error', self.on_error)
self.ws.emitter.on('on_close', self.on_close)
self.ws.emitter.on('on_message', self.on_message)
self.ws.run_forever()
def on_message(self, _, msg):
def on_message(self, msg):
try:
data = self.encoder.decode(msg)
self.packets.emit(VoiceOPCode[data['op']], data['d'])
except:
self.log.exception('Failed to parse voice gateway message: ')
def on_error(self, _, err):
# TODO
def on_error(self, err):
# TODO: raise an exception here
self.log.warning('Voice websocket error: {}'.format(err))
def on_open(self, _):
def on_open(self):
self.send(VoiceOPCode.IDENTIFY, {
'server_id': self.channel.guild_id,
'user_id': self.client.state.me.id,
@ -198,12 +246,15 @@ class VoiceClient(LoggingClass):
'token': self.token
})
def on_close(self, _, code, error):
# TODO
def on_close(self, code, error):
self.log.warning('Voice websocket disconnected (%s, %s)', code, error)
if self.state == VoiceState.CONNECTED:
self.log.info('Attempting voice reconnection')
self.connect()
def connect(self, timeout=5, mute=False, deaf=False):
self.state = VoiceState.AWAITING_ENDPOINT
self.set_state(VoiceState.AWAITING_ENDPOINT)
self.update_listener = self.client.events.on('VoiceServerUpdate', self.on_voice_server_update)
@ -214,11 +265,11 @@ class VoiceClient(LoggingClass):
'channel_id': int(self.channel.id),
})
if not self.connected.wait(timeout) or self.state != VoiceState.CONNECTED:
if not self.state_emitter.once(VoiceState.CONNECTED, timeout=timeout):
raise VoiceException('Failed to connect to voice', self)
def disconnect(self):
self.state = VoiceState.DISCONNECTED
self.set_state(VoiceState.DISCONNECTED)
if self.heartbeat_task:
self.heartbeat_task.kill()
@ -236,3 +287,6 @@ class VoiceClient(LoggingClass):
'guild_id': int(self.channel.guild_id),
'channel_id': None,
})
def send_frame(self, *args, **kwargs):
self.udp.send_frame(*args, **kwargs)

149
disco/voice/opus.py

@ -0,0 +1,149 @@
import sys
import array
import ctypes
import ctypes.util
from holster.enum import Enum
from disco.util.logging import LoggingClass
c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
c_float_ptr = ctypes.POINTER(ctypes.c_float)
class EncoderStruct(ctypes.Structure):
pass
class DecoderStruct(ctypes.Structure):
pass
EncoderStructPtr = ctypes.POINTER(EncoderStruct)
DecoderStructPtr = ctypes.POINTER(DecoderStruct)
class BaseOpus(LoggingClass):
BASE_EXPORTED = {
'opus_strerror': ([ctypes.c_int], ctypes.c_char_p),
}
EXPORTED = {}
def __init__(self, library_path=None):
self.path = library_path or self.find_library()
self.lib = ctypes.cdll.LoadLibrary(self.path)
methods = {}
methods.update(self.BASE_EXPORTED)
methods.update(self.EXPORTED)
for name, item in methods.items():
func = getattr(self.lib, name)
if item[0]:
func.argtypes = item[0]
func.restype = item[1]
setattr(self, name, func)
@staticmethod
def find_library():
if sys.platform == 'win32':
raise Exception('Cannot auto-load opus on Windows, please specify full library path')
return ctypes.util.find_library('opus')
Application = Enum(
AUDIO=2049,
VOIP=2048,
LOWDELAY=2051
)
Control = Enum(
SET_BITRATE=4002,
SET_BANDWIDTH=4008,
SET_FEC=4012,
SET_PLP=4014,
)
class OpusEncoder(BaseOpus):
EXPORTED = {
'opus_encoder_get_size': ([ctypes.c_int], ctypes.c_int),
'opus_encoder_create': ([ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr),
'opus_encode': ([EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32),
'opus_encoder_ctl': (None, ctypes.c_int32),
'opus_encoder_destroy': ([EncoderStructPtr], None),
}
def __init__(self, sampling_rate, channels, application=Application.AUDIO, library_path=None):
super(OpusEncoder, self).__init__(library_path)
self.sampling_rate = sampling_rate
self.channels = channels
self.application = application
self._inst = None
@property
def inst(self):
if not self._inst:
self._inst = self.create()
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
return self._inst
def set_bitrate(self, kbps):
kbps = min(128, max(16, int(kbps)))
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_BITRATE), kbps * 1024)
if ret < 0:
raise Exception('Failed to set bitrate to {}: {}'.format(kbps, ret))
def set_fec(self, value):
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_FEC), int(value))
if ret < 0:
raise Exception('Failed to set FEC to {}: {}'.format(value, ret))
def set_expected_packet_loss_percent(self, perc):
ret = self.opus_encoder_ctl(self.inst, int(Control.SET_PLP), min(100, max(0, int(perc * 100))))
if ret < 0:
raise Exception('Failed to set PLP to {}: {}'.format(perc, ret))
def create(self):
ret = ctypes.c_int()
result = self.opus_encoder_create(self.sampling_rate, self.channels, self.application.value, ctypes.byref(ret))
if ret.value != 0:
raise Exception('Failed to create opus encoder: {}'.format(ret.value))
return result
def __del__(self):
if hasattr(self, '_inst') and self._inst:
self.opus_encoder_destroy(self._inst)
self._inst = None
def encode(self, pcm, frame_size):
max_data_bytes = len(pcm)
pcm = ctypes.cast(pcm, c_int16_ptr)
data = (ctypes.c_char * max_data_bytes)()
ret = self.opus_encode(self.inst, pcm, frame_size, data, max_data_bytes)
if ret < 0:
raise Exception('Failed to encode: {}'.format(ret))
# TODO: py3
return array.array('b', data[:ret]).tostring()
class OpusDecoder(BaseOpus):
pass

354
disco/voice/playable.py

@ -0,0 +1,354 @@
import abc
import six
import types
import gevent
import struct
import subprocess
from gevent.lock import Semaphore
from gevent.queue import Queue
from disco.voice.opus import OpusEncoder
try:
from cStringIO import cStringIO as BufferedIO
except:
if six.PY2:
from StringIO import StringIO as BufferedIO
else:
from io import BytesIO as BufferedIO
OPUS_HEADER_SIZE = struct.calcsize('<h')
class AbstractOpus(object):
def __init__(self, sampling_rate=48000, frame_length=20, channels=2):
self.sampling_rate = sampling_rate
self.frame_length = frame_length
self.channels = 2
self.sample_size = 2 * self.channels
self.samples_per_frame = int(self.sampling_rate / 1000 * self.frame_length)
self.frame_size = self.samples_per_frame * self.sample_size
class BaseUtil(object):
def pipe(self, other, *args, **kwargs):
child = other(self, *args, **kwargs)
setattr(child, 'metadata', self.metadata)
setattr(child, '_parent', self)
return child
@property
def metadata(self):
return getattr(self, '_metadata', None)
@metadata.setter
def metadata(self, value):
self._metadata = value
@six.add_metaclass(abc.ABCMeta)
class BasePlayable(BaseUtil):
@abc.abstractmethod
def next_frame(self):
raise NotImplementedError
@six.add_metaclass(abc.ABCMeta)
class BaseInput(BaseUtil):
@abc.abstractmethod
def read(self, size):
raise NotImplementedError
@abc.abstractmethod
def fileobj(self):
raise NotImplementedError
class OpusFilePlayable(BasePlayable, AbstractOpus):
"""
An input which reads opus data from a file or file-like object.
"""
def __init__(self, fobj, *args, **kwargs):
super(OpusFilePlayable, self).__init__(*args, **kwargs)
self.fobj = fobj
self.done = False
def next_frame(self):
if self.done:
return None
header = self.fobj.read(OPUS_HEADER_SIZE)
if len(header) < OPUS_HEADER_SIZE:
self.done = True
return None
data_size = struct.unpack('<h', header)[0]
data = self.fobj.read(data_size)
if len(data) < data_size:
self.done = True
return None
return data
class FFmpegInput(BaseInput, AbstractOpus):
def __init__(self, source='-', command='avconv', streaming=False, **kwargs):
super(FFmpegInput, self).__init__(**kwargs)
if source:
self.source = source
self.streaming = streaming
self.command = command
self._buffer = None
self._proc = None
def read(self, sz):
if self.streaming:
raise TypeError('Cannot read from a streaming FFmpegInput')
# First read blocks until the subprocess finishes
if not self._buffer:
data, _ = self.proc.communicate()
self._buffer = BufferedIO(data)
# Subsequent reads can just do dis thang
return self._buffer.read(sz)
def fileobj(self):
if self.streaming:
return self.proc.stdout
else:
return self
@property
def proc(self):
if not self._proc:
if callable(self.source):
self.source = self.source(self)
if isinstance(self.source, (tuple, list)):
self.source, self.metadata = self.source
args = [
self.command,
'-i', str(self.source),
'-f', 's16le',
'-ar', str(self.sampling_rate),
'-ac', str(self.channels),
'-loglevel', 'warning',
'pipe:1'
]
self._proc = subprocess.Popen(args, stdin=None, stdout=subprocess.PIPE)
return self._proc
class YoutubeDLInput(FFmpegInput):
def __init__(self, url=None, ie_info=None, *args, **kwargs):
super(YoutubeDLInput, self).__init__(None, *args, **kwargs)
self._url = url
self._ie_info = ie_info
self._info = None
self._info_lock = Semaphore()
@property
def info(self):
with self._info_lock:
if not self._info:
import youtube_dl
ydl = youtube_dl.YoutubeDL({'format': 'webm[abr>0]/bestaudio/best'})
if self._url:
obj = ydl.extract_info(self._url, download=False, process=False)
if 'entries' in obj:
self._ie_info = obj['entries'][0]
else:
self._ie_info = obj
self._info = ydl.process_ie_result(self._ie_info, download=False)
return self._info
@property
def _metadata(self):
return self.info
@classmethod
def many(cls, url, *args, **kwargs):
import youtube_dl
ydl = youtube_dl.YoutubeDL({'format': 'webm[abr>0]/bestaudio/best'})
info = ydl.extract_info(url, download=False, process=False)
if 'entries' not in info:
yield cls(ie_info=info, *args, **kwargs)
raise StopIteration
for item in info['entries']:
yield cls(ie_info=item, *args, **kwargs)
@property
def source(self):
return self.info['url']
class BufferedOpusEncoderPlayable(BasePlayable, OpusEncoder, AbstractOpus):
def __init__(self, source, *args, **kwargs):
self.source = source
self.frames = Queue(kwargs.pop('queue_size', 4096))
# Call the AbstractOpus constructor, as we need properties it sets
AbstractOpus.__init__(self, *args, **kwargs)
# Then call the OpusEncoder constructor, which requires some properties
# that AbstractOpus sets up
OpusEncoder.__init__(self, self.sampling_rate, self.channels)
# Spawn the encoder loop
gevent.spawn(self._encoder_loop)
def _encoder_loop(self):
while self.source:
raw = self.source.read(self.frame_size)
if len(raw) < self.frame_size:
break
self.frames.put(self.encode(raw, self.samples_per_frame))
gevent.idle()
self.source = None
self.frames.put(None)
def next_frame(self):
return self.frames.get()
class DCADOpusEncoderPlayable(BasePlayable, AbstractOpus, OpusEncoder):
def __init__(self, source, *args, **kwargs):
self.source = source
self.command = kwargs.pop('command', 'dcad')
super(DCADOpusEncoderPlayable, self).__init__(*args, **kwargs)
self._done = False
self._proc = None
@property
def proc(self):
if not self._proc:
source = obj = self.source.fileobj()
if not hasattr(obj, 'fileno'):
source = subprocess.PIPE
self._proc = subprocess.Popen([
self.command,
'--channels', str(self.channels),
'--rate', str(self.sampling_rate),
'--size', str(self.samples_per_frame),
'--bitrate', '128',
'--fec',
'--packet-loss-percent', '30',
'--input', 'pipe:0',
'--output', 'pipe:1',
], stdin=source, stdout=subprocess.PIPE)
def writer():
while True:
data = obj.read(2048)
if len(data) > 0:
self._proc.stdin.write(data)
if len(data) < 2048:
break
if source == subprocess.PIPE:
gevent.spawn(writer)
return self._proc
def next_frame(self):
if self._done:
return None
header = self.proc.stdout.read(OPUS_HEADER_SIZE)
if len(header) < OPUS_HEADER_SIZE:
self._done = True
return
size = struct.unpack('<h', header)[0]
data = self.proc.stdout.read(size)
if len(data) < size:
self._done = True
return
return data
class FileProxyPlayable(BasePlayable, AbstractOpus):
def __init__(self, other, output, *args, **kwargs):
self.flush = kwargs.pop('flush', False)
self.on_complete = kwargs.pop('on_complete', None)
super(FileProxyPlayable, self).__init__(*args, **kwargs)
self.other = other
self.output = output
def next_frame(self):
frame = self.other.next_frame()
if frame:
self.output.write(struct.pack('<h', len(frame)))
self.output.write(frame)
if self.flush:
self.output.flush()
else:
self.output.flush()
self.on_complete()
self.output.close()
return frame
class PlaylistPlayable(BasePlayable, AbstractOpus):
def __init__(self, items, *args, **kwargs):
super(PlaylistPlayable, self).__init__(*args, **kwargs)
self.items = items
self.now_playing = None
def _get_next(self):
if isinstance(self.items, types.GeneratorType):
return next(self.items, None)
return self.items.pop()
def next_frame(self):
if not self.items:
return
if not self.now_playing:
self.now_playing = self._get_next()
if not self.now_playing:
return
frame = self.now_playing.next_frame()
if not frame:
return self.next_frame()
return frame
class MemoryBufferedPlayable(BasePlayable, AbstractOpus):
def __init__(self, other, *args, **kwargs):
from gevent.queue import Queue
super(MemoryBufferedPlayable, self).__init__(*args, **kwargs)
self.frames = Queue()
self.other = other
gevent.spawn(self._buffer)
def _buffer(self):
while True:
frame = self.other.next_frame()
if not frame:
break
self.frames.put(frame)
self.frames.put(None)
def next_frame(self):
return self.frames.get()

126
disco/voice/player.py

@ -0,0 +1,126 @@
import time
import gevent
from six.moves import queue
from holster.enum import Enum
from holster.emitter import Emitter
from disco.voice.client import VoiceState
MAX_TIMESTAMP = 4294967295
class Player(object):
Events = Enum(
'START_PLAY',
'STOP_PLAY',
'PAUSE_PLAY',
'RESUME_PLAY',
'DISCONNECT'
)
def __init__(self, client):
self.client = client
# Queue contains playable items
self.queue = queue.Queue()
# Whether we're playing music (true for lifetime)
self.playing = True
# Set to an event when playback is paused
self.paused = None
# Current playing item
self.now_playing = None
# Current play task
self.play_task = None
# Core task
self.run_task = gevent.spawn(self.run)
# Event triggered when playback is complete
self.complete = gevent.event.Event()
# Event emitter for metadata
self.events = Emitter(gevent.spawn)
def disconnect(self):
self.client.disconnect()
self.events.emit(self.Events.DISCONNECT)
def skip(self):
self.play_task.kill()
def pause(self):
if self.paused:
return
self.paused = gevent.event.Event()
self.events.emit(self.Events.PAUSE_PLAY)
def resume(self):
self.paused.set()
self.paused = None
self.events.emit(self.Events.RESUME_PLAY)
def play(self, item):
# Grab the first frame before we start anything else, sometimes playables
# can do some lengthy async tasks here to setup the playable and we
# don't want that lerp the first N frames of the playable into playing
# faster
frame = item.next_frame()
if frame is None:
return
start = time.time()
loops = 0
while True:
loops += 1
if self.paused:
self.client.set_speaking(False)
self.paused.wait()
gevent.sleep(2)
self.client.set_speaking(True)
start = time.time()
loops = 0
if self.client.state == VoiceState.DISCONNECTED:
return
if self.client.state != VoiceState.CONNECTED:
self.client.state_emitter.wait(VoiceState.CONNECTED)
self.client.send_frame(frame)
self.client.timestamp += item.samples_per_frame
if self.client.timestamp > MAX_TIMESTAMP:
self.client.timestamp = 0
frame = item.next_frame()
if frame is None:
return
next_time = start + 0.02 * loops
delay = max(0, 0.02 + (next_time - time.time()))
gevent.sleep(delay)
def run(self):
self.client.set_speaking(True)
while self.playing:
self.now_playing = self.queue.get()
self.events.emit(self.Events.START_PLAY, self.now_playing)
self.play_task = gevent.spawn(self.play, self.now_playing)
self.play_task.join()
self.events.emit(self.Events.STOP_PLAY, self.now_playing)
if self.client.state == VoiceState.DISCONNECTED:
self.playing = False
self.complete.set()
return
self.client.set_speaking(False)
self.disconnect()

4
docs/SUMMARY.md

@ -4,7 +4,5 @@
* [Installation and Setup](INSTALLATION.md)
* [Building a Bot](BUILDING_A_BOT.md)
* API Docs
* [Client](CLIENT.md)
* Types
* [Message](types/MESSAGE.md)
* [Client](api/disco_client.md)

3
docs/types/MESSAGE.md

@ -1,3 +0,0 @@
# Message
TODO

2
examples/basic_plugin.py

@ -50,7 +50,7 @@ class BasicPlugin(Plugin):
if not users:
event.msg.reply("Couldn't find user for your query: `{}`".format(query))
elif len(users) > 1:
event.msg.reply('I found too many userse ({}) for your query: `{}`'.format(len(users), query))
event.msg.reply('I found too many users ({}) for your query: `{}`'.format(len(users), query))
else:
user = users[0]
parts = []

52
examples/music.py

@ -0,0 +1,52 @@
from disco.bot import Plugin
from disco.bot.command import CommandError
from disco.voice.player import Player
from disco.voice.playable import YoutubeDLInput, BufferedOpusEncoderPlayable
from disco.voice.client import VoiceException
class MusicPlugin(Plugin):
def load(self, ctx):
super(MusicPlugin, self).load(ctx)
self.guilds = {}
@Plugin.command('join')
def on_join(self, event):
if event.guild.id in self.guilds:
return event.msg.reply("I'm already playing music here.")
state = event.guild.get_member(event.author).get_voice_state()
if not state:
return event.msg.reply('You must be connected to voice to use that command.')
try:
client = state.channel.connect()
except VoiceException as e:
return event.msg.reply('Failed to connect to voice: `{}`'.format(e))
self.guilds[event.guild.id] = Player(client)
self.guilds[event.guild.id].complete.wait()
del self.guilds[event.guild.id]
def get_player(self, guild_id):
if guild_id not in self.guilds:
raise CommandError("I'm not currently playing music here.")
return self.guilds.get(guild_id)
@Plugin.command('leave')
def on_leave(self, event):
player = self.get_player(event.guild.id)
player.disconnect()
@Plugin.command('play', '<url:str>')
def on_play(self, event, url):
item = YoutubeDLInput(url).pipe(BufferedOpusEncoderPlayable)
self.get_player(event.guild.id).queue.put(item)
@Plugin.command('pause')
def on_pause(self, event):
self.get_player(event.guild.id).pause()
@Plugin.command('resume')
def on_resume(self, event):
self.get_player(event.guild.id).resume()

33
examples/storage.py

@ -0,0 +1,33 @@
from disco.bot import Plugin
class BasicPlugin(Plugin):
def load(self, ctx):
super(BasicPlugin, self).load(ctx)
self.tags = self.storage.guild('tags')
@Plugin.command('add', '<name:str> <value:str...>', group='tags')
def on_tags_add(self, event, name, value):
if name in self.tags:
return event.msg.reply('That tag already exists!')
self.tags[name] = value
return event.msg.reply(u':ok_hand: created the tag {}'.format(name), sanitize=True)
@Plugin.command('get', '<name:str>', group='tags')
def on_tags_get(self, event, name):
if name not in self.tags:
return event.msg.reply('That tag does not exist!')
return event.msg.reply(self.tags[name], sanitize=True)
@Plugin.command('delete', '<name:str>', group='tags', aliases=['del', 'rmv', 'remove'])
def on_tags_delete(self, event, name):
if name not in self.tags:
return event.msg.reply('That tag does not exist!')
del self.tags[name]
return event.msg.reply(u':ok_hand: I deleted the {} tag for you'.format(
name
), sanitize=True)

8
requirements.txt

@ -1,6 +1,6 @@
gevent==1.1.2
holster==1.0.11
gevent==1.2.1
holster==1.0.14
inflection==0.3.1
requests==2.11.1
requests==2.13.0
six==1.10.0
websocket-client==0.37.0
websocket-client==0.40.0

15
setup.py

@ -2,12 +2,25 @@ from setuptools import setup, find_packages
from disco import VERSION
def run_tests():
import unittest
test_loader = unittest.TestLoader()
test_suite = test_loader.discover('tests', pattern='test_*.py')
return test_suite
with open('requirements.txt') as f:
requirements = f.readlines()
with open('README.md') as f:
readme = f.read()
extras_require = {
'voice': ['pynacl==1.1.2'],
'performance': ['erlpack==0.3.2'],
}
setup(
name='disco-py',
author='b1nzy',
@ -19,6 +32,8 @@ setup(
long_description=readme,
include_package_data=True,
install_requires=requirements,
extras_require=extras_require,
test_suite='setup.run_tests',
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: MIT License',

47
tests/test_bot.py

@ -0,0 +1,47 @@
from unittest import TestCase
from disco.client import ClientConfig, Client
from disco.bot.bot import Bot
from disco.bot.command import Command
class MockBot(Bot):
@property
def commands(self):
return getattr(self, '_commands', [])
class TestBot(TestCase):
def setUp(self):
self.client = Client(ClientConfig(
{'config': 'TEST_TOKEN'}
))
self.bot = MockBot(self.client)
def test_command_abbreviation(self):
groups = ['config', 'copy', 'copez', 'copypasta']
result = self.bot.compute_group_abbrev(groups)
self.assertDictEqual(result, {
'config': 'con',
'copypasta': 'copy',
'copez': 'cope',
})
def test_command_abbreivation_conflicting(self):
groups = ['cat', 'cap', 'caz', 'cas']
result = self.bot.compute_group_abbrev(groups)
self.assertDictEqual(result, {})
def test_many_commands(self):
self.bot._commands = [
Command(None, None, 'test{}'.format(i), '<test:str>')
for i in range(1000)
]
self.bot.compute_command_matches_re()
match = self.bot.command_matches_re.match('test5 123')
self.assertNotEqual(match, None)
match = self.bot._commands[0].compiled_regex.match('test0 123 456')
self.assertEqual(match.group(1).strip(), 'test0')
self.assertEqual(match.group(2).strip(), '123 456')

21
tests/test_channel.py

@ -0,0 +1,21 @@
from unittest import TestCase
from disco.types.channel import Channel, ChannelType
class TestChannel(TestCase):
def test_nsfw_channel(self):
channel = Channel(
name='nsfw-testing',
type=ChannelType.GUILD_TEXT)
self.assertTrue(channel.is_nsfw)
channel = Channel(
name='nsfw-testing',
type=ChannelType.GUILD_VOICE)
self.assertFalse(channel.is_nsfw)
channel = Channel(
name='nsfw_testing',
type=ChannelType.GUILD_TEXT)
self.assertFalse(channel.is_nsfw)

32
tests/test_embeds.py

@ -0,0 +1,32 @@
from unittest import TestCase
from datetime import datetime
from disco.types.message import MessageEmbed
class TestEmbeds(TestCase):
def test_empty_embed(self):
embed = MessageEmbed()
self.assertDictEqual(
embed.to_dict(),
{
'image': {},
'author': {},
'video': {},
'thumbnail': {},
'footer': {},
'fields': [],
'type': 'rich',
})
def test_embed(self):
embed = MessageEmbed(
title='Test Title',
description='Test Description',
url='https://test.url/'
)
obj = embed.to_dict()
self.assertEqual(obj['title'], 'Test Title')
self.assertEqual(obj['description'], 'Test Description')
self.assertEqual(obj['url'], 'https://test.url/')

42
tests/test_imports.py

@ -0,0 +1,42 @@
"""
This module tests that all of disco can be imported, mostly to help reduce issues
with untested code that will not even parse/run on Py2/3
"""
from disco.api.client import *
from disco.api.http import *
from disco.api.ratelimit import *
from disco.bot.bot import *
from disco.bot.command import *
from disco.bot.parser import *
from disco.bot.plugin import *
from disco.bot.storage import *
from disco.gateway.client import *
from disco.gateway.events import *
from disco.gateway.ipc import *
from disco.gateway.packets import *
# Not imported, GIPC is required but not provided by default
# from disco.gateway.sharder import *
from disco.types.base import *
from disco.types.channel import *
from disco.types.guild import *
from disco.types.invite import *
from disco.types.message import *
from disco.types.permissions import *
from disco.types.user import *
from disco.types.voice import *
from disco.types.webhook import *
from disco.util.backdoor import *
from disco.util.config import *
from disco.util.functional import *
from disco.util.hashmap import *
from disco.util.limiter import *
from disco.util.logging import *
from disco.util.serializer import *
from disco.util.snowflake import *
from disco.util.token import *
from disco.util.websocket import *
from disco.voice.client import *
from disco.voice.opus import *
from disco.voice.packets import *
from disco.voice.playable import *
from disco.voice.player import *
Loading…
Cancel
Save