diff --git a/.gitignore b/.gitignore index 2e0769c..016ef0a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ dist/ disco*.egg-info/ docs/_build storage.db - _book/ node_modules/ +storage.json +*.dca diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..4e6a402 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,13 @@ +language: python + +cache: pip + +python: + - '2.7' + - '3.3' + - '3.4' + - '3.5' + - '3.6' + - 'nightly' + +script: 'python setup.py test' diff --git a/disco/__init__.py b/disco/__init__.py index 262e0b7..c3f2301 100644 --- a/disco/__init__.py +++ b/disco/__init__.py @@ -1 +1 @@ -VERSION = '0.0.7' +VERSION = '0.0.8' diff --git a/disco/api/client.py b/disco/api/client.py index 0c00539..04b9b32 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -1,8 +1,10 @@ import six import json +import warnings from disco.api.http import Routes, HTTPClient from disco.util.logging import LoggingClass +from disco.util.sanitize import S from disco.types.user import User from disco.types.message import Message @@ -88,29 +90,56 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) return Message.create(self.client, r.json()) - def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None): + def channels_messages_create(self, channel, content=None, nonce=None, tts=False, + attachment=None, attachments=[], embed=None, sanitize=False): + payload = { - 'content': content, 'nonce': nonce, 'tts': tts, } + if attachment: + attachments = [attachment] + warnings.warn( + 'attachment kwarg has been deprecated, switch to using attachments with a list', + DeprecationWarning) + + if content: + if sanitize: + content = S(content) + payload['content'] = content + if embed: payload['embed'] = embed.to_dict() - if attachment: - r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data={'payload_json': json.dumps(payload)}, files={ - 'file': (attachment[0], attachment[1]) - }) + if attachments: + if len(attachments) > 1: + files = { + 'file{}'.format(idx): tuple(i) for idx, i in enumerate(attachments) + } + else: + files = { + 'file': tuple(attachments[0]), + } + + r = self.http( + Routes.CHANNELS_MESSAGES_CREATE, + dict(channel=channel), + data={'payload_json': json.dumps(payload)}, + files=files + ) else: r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload) return Message.create(self.client, r.json()) - def channels_messages_modify(self, channel, message, content, embed=None): - payload = { - 'content': content, - } + def channels_messages_modify(self, channel, message, content=None, embed=None, sanitize=False): + payload = {} + + if content: + if sanitize: + content = S(content) + payload['content'] = content if embed: payload['embed'] = embed.to_dict() @@ -285,6 +314,10 @@ class APIClient(LoggingClass): def guilds_roles_delete(self, guild, role): self.http(Routes.GUILDS_ROLES_DELETE, dict(guild=guild, role=role)) + def guilds_invites_list(self, guild): + r = self.http(Routes.GUILDS_INVITES_LIST, dict(guild=guild)) + return Invite.create_map(self.client, r.json()) + def guilds_webhooks_list(self, guild): r = self.http(Routes.GUILDS_WEBHOOKS_LIST, dict(guild=guild)) return Webhook.create_map(self.client, r.json()) @@ -295,11 +328,11 @@ class APIClient(LoggingClass): def guilds_emojis_create(self, guild, **kwargs): r = self.http(Routes.GUILDS_EMOJIS_CREATE, dict(guild=guild), json=kwargs) - return GuildEmoji.create(self.client, r.json()) + return GuildEmoji.create(self.client, r.json(), guild_id=guild) def guilds_emojis_modify(self, guild, emoji, **kwargs): r = self.http(Routes.GUILDS_EMOJIS_MODIFY, dict(guild=guild, emoji=emoji), json=kwargs) - return GuildEmoji.create(self.client, r.json()) + return GuildEmoji.create(self.client, r.json(), guild_id=guild) def guilds_emojis_delete(self, guild, emoji): self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji)) @@ -311,6 +344,15 @@ class APIClient(LoggingClass): r = self.http(Routes.USERS_ME_PATCH, json=payload) return User.create(self.client, r.json()) + def users_me_guilds_delete(self, guild): + self.http(Routes.USERS_ME_GUILDS_DELETE, dict(guild=guild)) + + def users_me_dms_create(self, recipient_id): + r = self.http(Routes.USERS_ME_DMS_CREATE, json={ + 'recipient_id': recipient_id, + }) + return Channel.create(self.client, r.json()) + def invites_get(self, invite): r = self.http(Routes.INVITES_GET, dict(invite=invite)) return Invite.create(self.client, r.json()) diff --git a/disco/api/http.py b/disco/api/http.py index 088b69e..462ad8e 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -108,7 +108,7 @@ class Routes(object): USERS_ME_GET = (HTTPMethod.GET, USERS + '/@me') USERS_ME_PATCH = (HTTPMethod.PATCH, USERS + '/@me') USERS_ME_GUILDS_LIST = (HTTPMethod.GET, USERS + '/@me/guilds') - USERS_ME_GUILDS_LEAVE = (HTTPMethod.DELETE, USERS + '/@me/guilds/{guild}') + USERS_ME_GUILDS_DELETE = (HTTPMethod.DELETE, USERS + '/@me/guilds/{guild}') USERS_ME_DMS_LIST = (HTTPMethod.GET, USERS + '/@me/channels') USERS_ME_DMS_CREATE = (HTTPMethod.POST, USERS + '/@me/channels') USERS_ME_CONNECTIONS_LIST = (HTTPMethod.GET, USERS + '/@me/connections') @@ -176,7 +176,7 @@ class HTTPClient(LoggingClass): A simple HTTP client which wraps the requests library, adding support for Discords rate-limit headers, authorization, and request/response validation. """ - BASE_URL = 'https://discordapp.com/api/v6' + BASE_URL = 'https://discordapp.com/api/v7' MAX_RETRIES = 5 def __init__(self, token): @@ -189,13 +189,15 @@ class HTTPClient(LoggingClass): self.limiter = RateLimiter() self.headers = { - 'Authorization': 'Bot ' + token, 'User-Agent': 'DiscordBot (https://github.com/b1naryth1ef/disco {}) Python/{} requests/{}'.format( disco_version, py_version, requests_version), } + if token: + self.headers['Authorization'] = 'Bot ' + token + def __call__(self, route, args=None, **kwargs): return self.call(route, args, **kwargs) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 7693f50..18260b9 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -65,6 +65,7 @@ class BotConfig(Config): The directory plugin configuration is located within. """ levels = {} + plugins = [] plugin_config = {} commands_enabled = True @@ -81,12 +82,13 @@ class BotConfig(Config): commands_group_abbrev = True plugin_config_provider = None - plugin_config_format = 'yaml' + plugin_config_format = 'json' plugin_config_dir = 'config' storage_enabled = True - storage_provider = 'memory' - storage_config = {} + storage_fsync = True + storage_serializer = 'json' + storage_path = 'storage.json' class Bot(LoggingClass): @@ -195,35 +197,47 @@ class Bot(LoggingClass): Called when a plugin is loaded/unloaded to recompute internal state. """ if self.config.commands_group_abbrev: - self.compute_group_abbrev() + groups = set(command.group for command in self.commands if command.group) + self.group_abbrev = self.compute_group_abbrev(groups) self.compute_command_matches_re() - def compute_group_abbrev(self): + def compute_group_abbrev(self, groups): """ Computes all possible abbreviations for a command grouping. """ - self.group_abbrev = {} - groups = set(command.group for command in self.commands if command.group) - + # For the first pass, we just want to compute each groups possible + # abbreviations that don't conflict with eachother. + possible = {} for group in groups: - grp = group - while grp: - # If the group already exists, means someone else thought they - # could use it so we need yank it from them (and not use it) - if grp in list(six.itervalues(self.group_abbrev)): - self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp} + for index in range(len(group)): + current = group[:index] + if current in possible: + possible[current] = None else: - self.group_abbrev[group] = grp + possible[current] = group - grp = grp[:-1] + # Now, we want to compute the actual shortest abbreivation out of the + # possible ones + result = {} + for abbrev, group in six.iteritems(possible): + if not group: + continue + + if group in result: + if len(abbrev) < len(result[group]): + result[group] = abbrev + else: + result[group] = abbrev + + return result def compute_command_matches_re(self): """ Computes a single regex which matches all possible command combinations. """ commands = list(self.commands) - re_str = '|'.join(command.regex for command in commands) + re_str = '|'.join(command.regex(grouped=False) for command in commands) if re_str: self.command_matches_re = re.compile(re_str, re.I) else: @@ -267,7 +281,10 @@ class Bot(LoggingClass): if msg.guild: member = msg.guild.get_member(self.client.state.me) if member: - content = content.replace(member.mention, '', 1) + # If nickname is set, filter both the normal and nick mentions + if member.nick: + content = content.replace(member.mention, '', 1) + content = content.replace(member.user.mention, '', 1) else: content = content.replace(self.client.state.me.mention, '', 1) elif mention_everyone: @@ -355,10 +372,10 @@ class Bot(LoggingClass): if event.message.author.id == self.client.state.me.id: return - if self.config.commands_allow_edit: - self.last_message_cache[event.message.channel_id] = (event.message, False) + result = self.handle_message(event.message) - self.handle_message(event.message) + if self.config.commands_allow_edit: + self.last_message_cache[event.message.channel_id] = (event.message, result) def on_message_update(self, event): if self.config.commands_allow_edit: diff --git a/disco/bot/command.py b/disco/bot/command.py index 3d0398a..9a03ebe 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -6,6 +6,7 @@ from disco.bot.parser import ArgumentSet, ArgumentError from disco.util.functional import cached_property ARGS_REGEX = '(?: ((?:\n|.)*)$|$)' +ARGS_UNGROUPED_REGEX = '(?: (?:\n|.)*$|$)' USER_MENTION_RE = re.compile('<@!?([0-9]+)>') ROLE_MENTION_RE = re.compile('<@&([0-9]+)>') @@ -44,11 +45,11 @@ class CommandEvent(object): self.command = command self.msg = msg self.match = match - self.name = self.match.group(0) + self.name = self.match.group(1).strip() self.args = [] - if self.match.group(1): - self.args = [i for i in self.match.group(1).strip().split(' ') if i] + if self.match.group(2): + self.args = [i for i in self.match.group(2).strip().split(' ') if i] @property def codeblock(self): @@ -140,6 +141,10 @@ class Command(object): self.update(*args, **kwargs) + @property + def name(self): + return self.triggers[0] + def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) @@ -218,10 +223,9 @@ class Command(object): """ A compiled version of this command's regex. """ - return re.compile(self.regex, re.I) + return re.compile(self.regex(), re.I) - @property - def regex(self): + def regex(self, grouped=True): """ The regex string that defines/triggers this command. """ @@ -231,10 +235,13 @@ class Command(object): group = '' if self.group: if self.group in self.plugin.bot.group_abbrev: - group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group)) + group = '(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group)) else: group = self.group + ' ' - return '^{}(?:{})'.format(group, '|'.join(self.triggers)) + ARGS_REGEX + return ('^{}({})' if grouped else '^{}(?:{})').format( + group, + '|'.join(self.triggers) + ) + (ARGS_REGEX if grouped else ARGS_UNGROUPED_REGEX) def execute(self, event): """ @@ -247,9 +254,10 @@ class Command(object): Whether this command was successful """ if len(event.args) < self.args.required_length: - raise CommandError('{} requires {} arguments (passed {})'.format( + raise CommandError(u'Command {} requires {} arguments (`{}`) passed {}'.format( event.name, self.args.required_length, + self.raw_args, len(event.args) )) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 722abe6..e9cabfc 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -15,12 +15,6 @@ TYPE_MAP = { 'snowflake': lambda ctx, data: int(data), } -try: - import dateparser - TYPE_MAP['duration'] = lambda ctx, data: dateparser.parse(data, settings={'TIMEZONE': 'UTC'}) -except ImportError: - pass - def to_bool(ctx, data): if data in BOOL_OPTS: diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index aeed5b3..088f4da 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -209,15 +209,22 @@ class Plugin(LoggingClass, PluginDeco): def handle_exception(self, greenlet, event): pass - def wait_for_event(self, event_name, **kwargs): + def wait_for_event(self, event_name, conditional=None, **kwargs): result = AsyncResult() listener = None def _event_callback(event): for k, v in kwargs.items(): - if getattr(event, k) != v: + obj = event + for inst in k.split('__'): + obj = getattr(obj, inst) + + if obj != v: break else: + if conditional and not conditional(event): + return + listener.remove() return result.set(event) diff --git a/disco/bot/providers/__init__.py b/disco/bot/providers/__init__.py deleted file mode 100644 index 3ec985f..0000000 --- a/disco/bot/providers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import inspect -import importlib - -from .base import BaseProvider - - -def load_provider(name): - try: - mod = importlib.import_module('disco.bot.providers.' + name) - except ImportError: - mod = importlib.import_module(name) - - for entry in filter(inspect.isclass, map(lambda i: getattr(mod, i), dir(mod))): - if issubclass(entry, BaseProvider) and entry != BaseProvider: - return entry diff --git a/disco/bot/providers/base.py b/disco/bot/providers/base.py deleted file mode 100644 index 0f14f3d..0000000 --- a/disco/bot/providers/base.py +++ /dev/null @@ -1,134 +0,0 @@ -import six -import pickle - -from six.moves import map, UserDict - - -ROOT_SENTINEL = u'\u200B' -SEP_SENTINEL = u'\u200D' -OBJ_SENTINEL = u'\u200C' -CAST_SENTINEL = u'\u24EA' - - -def join_key(*args): - nargs = [] - for arg in args: - if not isinstance(arg, six.string_types): - arg = CAST_SENTINEL + pickle.dumps(arg) - nargs.append(arg) - return SEP_SENTINEL.join(nargs) - - -def true_key(key): - key = key.rsplit(SEP_SENTINEL, 1)[-1] - if key.startswith(CAST_SENTINEL): - return pickle.loads(key) - return key - - -class BaseProvider(object): - def __init__(self, config): - self.config = config - self.data = {} - - def exists(self, key): - return key in self.data - - def keys(self, other): - count = other.count(SEP_SENTINEL) + 1 - for key in self.data.keys(): - if key.startswith(other) and key.count(SEP_SENTINEL) == count: - yield key - - def get_many(self, keys): - for key in keys: - yield key, self.get(key) - - def get(self, key): - return self.data[key] - - def set(self, key, value): - self.data[key] = value - - def delete(self, key): - del self.data[key] - - def load(self): - pass - - def save(self): - pass - - def root(self): - return StorageDict(self) - - -class StorageDict(UserDict): - def __init__(self, parent_or_provider, key=None): - if isinstance(parent_or_provider, BaseProvider): - self.provider = parent_or_provider - self.parent = None - else: - self.parent = parent_or_provider - self.provider = self.parent.provider - self._key = key or ROOT_SENTINEL - - def keys(self): - return map(true_key, self.provider.keys(self.key)) - - def values(self): - for key in self.keys(): - yield self.provider.get(key) - - def items(self): - for key in self.keys(): - yield (true_key(key), self.provider.get(key)) - - def ensure(self, key, typ=dict): - if key not in self: - self[key] = typ() - return self[key] - - def update(self, obj): - for k, v in six.iteritems(obj): - self[k] = v - - @property - def data(self): - obj = {} - - for raw, value in self.provider.get_many(self.provider.keys(self.key)): - key = true_key(raw) - - if value == OBJ_SENTINEL: - value = self.__class__(self, key=key).data - obj[key] = value - return obj - - @property - def key(self): - if self.parent is not None: - return join_key(self.parent.key, self._key) - return self._key - - def __setitem__(self, key, value): - if isinstance(value, dict): - obj = self.__class__(self, key) - obj.update(value) - value = OBJ_SENTINEL - - self.provider.set(join_key(self.key, key), value) - - def __getitem__(self, key): - res = self.provider.get(join_key(self.key, key)) - - if res == OBJ_SENTINEL: - return self.__class__(self, key) - - return res - - def __delitem__(self, key): - return self.provider.delete(join_key(self.key, key)) - - def __contains__(self, key): - return self.provider.exists(join_key(self.key, key)) diff --git a/disco/bot/providers/disk.py b/disco/bot/providers/disk.py deleted file mode 100644 index af259e1..0000000 --- a/disco/bot/providers/disk.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -import gevent - -from disco.util.serializer import Serializer -from .base import BaseProvider - - -class DiskProvider(BaseProvider): - def __init__(self, config): - super(DiskProvider, self).__init__(config) - self.format = config.get('format', 'pickle') - self.path = config.get('path', 'storage') + '.' + self.format - self.fsync = config.get('fsync', False) - self.fsync_changes = config.get('fsync_changes', 1) - - self.autosave_task = None - self.change_count = 0 - - def autosave_loop(self, interval): - while True: - gevent.sleep(interval) - self.save() - - def _on_change(self): - if self.fsync: - self.change_count += 1 - - if self.change_count >= self.fsync_changes: - self.save() - self.change_count = 0 - - def load(self): - if not os.path.exists(self.path): - return - - if self.config.get('autosave', True): - self.autosave_task = gevent.spawn( - self.autosave_loop, - self.config.get('autosave_interval', 120)) - - with open(self.path, 'r') as f: - self.data = Serializer.loads(self.format, f.read()) - - def save(self): - with open(self.path, 'w') as f: - f.write(Serializer.dumps(self.format, self.data)) - - def set(self, key, value): - super(DiskProvider, self).set(key, value) - self._on_change() - - def delete(self, key): - super(DiskProvider, self).delete(key) - self._on_change() diff --git a/disco/bot/providers/memory.py b/disco/bot/providers/memory.py deleted file mode 100644 index 17ad47b..0000000 --- a/disco/bot/providers/memory.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base import BaseProvider - - -class MemoryProvider(BaseProvider): - pass diff --git a/disco/bot/providers/redis.py b/disco/bot/providers/redis.py deleted file mode 100644 index f5e1375..0000000 --- a/disco/bot/providers/redis.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import absolute_import - -import redis - -from itertools import izip - -from disco.util.serializer import Serializer -from .base import BaseProvider, SEP_SENTINEL - - -class RedisProvider(BaseProvider): - def __init__(self, config): - super(RedisProvider, self).__init__(config) - self.format = config.get('format', 'pickle') - self.conn = None - - def load(self): - self.conn = redis.Redis( - host=self.config.get('host', 'localhost'), - port=self.config.get('port', 6379), - db=self.config.get('db', 0)) - - def exists(self, key): - return self.conn.exists(key) - - def keys(self, other): - count = other.count(SEP_SENTINEL) + 1 - for key in self.conn.scan_iter(u'{}*'.format(other)): - key = key.decode('utf-8') - if key.count(SEP_SENTINEL) == count: - yield key - - def get_many(self, keys): - keys = list(keys) - if not len(keys): - raise StopIteration - - for key, value in izip(keys, self.conn.mget(keys)): - yield (key, Serializer.loads(self.format, value)) - - def get(self, key): - return Serializer.loads(self.format, self.conn.get(key)) - - def set(self, key, value): - self.conn.set(key, Serializer.dumps(self.format, value)) - - def delete(self, key): - self.conn.delete(key) diff --git a/disco/bot/providers/rocksdb.py b/disco/bot/providers/rocksdb.py deleted file mode 100644 index 986268d..0000000 --- a/disco/bot/providers/rocksdb.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import absolute_import - -import six -import rocksdb - -from itertools import izip -from six.moves import map - -from disco.util.serializer import Serializer -from .base import BaseProvider, SEP_SENTINEL - - -class RocksDBProvider(BaseProvider): - def __init__(self, config): - super(RocksDBProvider, self).__init__(config) - self.format = config.get('format', 'pickle') - self.path = config.get('path', 'storage.db') - self.db = None - - @staticmethod - def k(k): - return bytes(k) if six.PY3 else str(k.encode('utf-8')) - - def load(self): - self.db = rocksdb.DB(self.path, rocksdb.Options(create_if_missing=True)) - - def exists(self, key): - return self.db.get(self.k(key)) is not None - - # TODO prefix extractor - def keys(self, other): - count = other.count(SEP_SENTINEL) + 1 - it = self.db.iterkeys() - it.seek_to_first() - - for key in it: - key = key.decode('utf-8') - if key.startswith(other) and key.count(SEP_SENTINEL) == count: - yield key - - def get_many(self, keys): - for key, value in izip(keys, self.db.multi_get(list(map(self.k, keys)))): - yield (key, Serializer.loads(self.format, value.decode('utf-8'))) - - def get(self, key): - return Serializer.loads(self.format, self.db.get(self.k(key)).decode('utf-8')) - - def set(self, key, value): - self.db.put(self.k(key), Serializer.dumps(self.format, value)) - - def delete(self, key): - self.db.delete(self.k(key)) diff --git a/disco/bot/storage.py b/disco/bot/storage.py index 812d79c..5b03a07 100644 --- a/disco/bot/storage.py +++ b/disco/bot/storage.py @@ -1,26 +1,87 @@ -from .providers import load_provider +import os +from six.moves import UserDict -class Storage(object): - def __init__(self, ctx, config): +from disco.util.hashmap import HashMap +from disco.util.serializer import Serializer + + +class StorageHashMap(HashMap): + def __init__(self, data): + self.data = data + + +class ContextAwareProxy(UserDict): + def __init__(self, ctx): self.ctx = ctx - self.config = config - self.provider = load_provider(config.provider)(config.config) - self.provider.load() - self.root = self.provider.root() @property - def plugin(self): - return self.root.ensure('plugins').ensure(self.ctx['plugin'].name) + def data(self): + return self.ctx() - @property - def guild(self): - return self.plugin.ensure('guilds').ensure(self.ctx['guild'].id) - @property - def channel(self): - return self.plugin.ensure('channels').ensure(self.ctx['channel'].id) +class StorageDict(UserDict): + def __init__(self, parent, data): + self._parent = parent + self.data = data - @property - def user(self): - return self.plugin.ensure('users').ensure(self.ctx['user'].id) + def update(self, other): + self.data.update(other) + self._parent._update() + + def __setitem__(self, key, value): + self.data[key] = value + self._parent._update() + + def __delitem__(self, key): + del self.data[key] + self._parent._update() + + +class Storage(object): + def __init__(self, ctx, config): + self._ctx = ctx + self._path = config.path + self._serializer = config.serializer + self._fsync = config.fsync + self._data = {} + + if os.path.exists(self._path): + with open(self._path, 'r') as f: + self._data = Serializer.loads(self._serializer, f.read()) + + def __getitem__(self, key): + if key not in self._data: + self._data[key] = {} + return StorageHashMap(StorageDict(self, self._data[key])) + + def _update(self): + if self._fsync: + self.save() + + def save(self): + if not self._path: + return + + with open(self._path, 'w') as f: + f.write(Serializer.dumps(self._serializer, self._data)) + + def guild(self, key): + return ContextAwareProxy( + lambda: self['_g{}:{}'.format(self._ctx['guild'].id, key)] + ) + + def channel(self, key): + return ContextAwareProxy( + lambda: self['_c{}:{}'.format(self._ctx['channel'].id, key)] + ) + + def plugin(self, key): + return ContextAwareProxy( + lambda: self['_p{}:{}'.format(self._ctx['plugin'].name, key)] + ) + + def user(self, key): + return ContextAwareProxy( + lambda: self['_u{}:{}'.format(self._ctx['user'].id, key)] + ) diff --git a/disco/client.py b/disco/client.py index a5791ee..478f226 100644 --- a/disco/client.py +++ b/disco/client.py @@ -15,13 +15,13 @@ from disco.util.backdoor import DiscoBackdoorServer class ClientConfig(Config): """ - Configuration for the :class:`Client`. + Configuration for the `Client`. Attributes ---------- token : str Discord authentication token, can be validated using the - :func:`disco.util.token.is_valid_token` function. + `disco.util.token.is_valid_token` function. shard_id : int The shard ID for the current client instance. shard_count : int @@ -53,32 +53,32 @@ class Client(LoggingClass): """ Class representing the base entry point that should be used in almost all implementation cases. This class wraps the functionality of both the REST API - (:class:`disco.api.client.APIClient`) and the realtime gateway API - (:class:`disco.gateway.client.GatewayClient`). + (`disco.api.client.APIClient`) and the realtime gateway API + (`disco.gateway.client.GatewayClient`). Parameters ---------- - config : :class:`ClientConfig` + config : `ClientConfig` Configuration for this client instance. Attributes ---------- - config : :class:`ClientConfig` + config : `ClientConfig` The runtime configuration for this client. - events : :class:`Emitter` + events : `Emitter` An emitter which emits Gateway events. - packets : :class:`Emitter` + packets : `Emitter` An emitter which emits Gateway packets. - state : :class:`State` + state : `State` The state tracking object. - api : :class:`APIClient` + api : `APIClient` The API client. - gw : :class:`GatewayClient` + gw : `GatewayClient` The gateway client. manhole_locals : dict Dictionary of local variables for each manhole connection. This can be modified to add/modify local variables. - manhole : Optional[:class:`BackdoorServer`] + manhole : Optional[`BackdoorServer`] Gevent backdoor server (if the manhole is enabled). """ def __init__(self, config): @@ -105,7 +105,21 @@ class Client(LoggingClass): localf=lambda: self.manhole_locals) self.manhole.start() - def update_presence(self, game=None, status=None, afk=False, since=0.0): + def update_presence(self, status, game=None, afk=False, since=0.0): + """ + Updates the current clients presence. + + Params + ------ + status : `user.Status` + The clients current status. + game : `user.Game` + If passed, the game object to set for the users presence. + afk : bool + Whether the client is currently afk. + since : float + How long the client has been afk for (in seconds). + """ if game and not isinstance(game, Game): raise TypeError('Game must be a Game model') @@ -126,12 +140,12 @@ class Client(LoggingClass): def run(self): """ - Run the client (e.g. the :class:`GatewayClient`) in a new greenlet. + Run the client (e.g. the `GatewayClient`) in a new greenlet. """ return gevent.spawn(self.gw.run) def run_forever(self): """ - Run the client (e.g. the :class:`GatewayClient`) in the current greenlet. + Run the client (e.g. the `GatewayClient`) in the current greenlet. """ return self.gw.run() diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 7386bb5..8ed210f 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -53,6 +53,8 @@ class GatewayClient(LoggingClass): self.session_id = None self.reconnects = 0 self.shutting_down = False + self.replaying = False + self.replayed_events = 0 # Cached gateway URL self._cached_gateway_url = None @@ -81,6 +83,8 @@ class GatewayClient(LoggingClass): obj = GatewayEvent.from_dispatch(self.client, packet) self.log.debug('Dispatching %s', obj.__class__.__name__) self.client.events.emit(obj.__class__.__name__, obj) + if self.replaying: + self.replayed_events += 1 def handle_heartbeat(self, _): self._send(OPCode.HEARTBEAT, self.seq) @@ -105,8 +109,9 @@ class GatewayClient(LoggingClass): self.reconnects = 0 def on_resumed(self, _): - self.log.info('Recieved RESUMED') + self.log.info('RESUME completed, replayed %s events', self.replayed_events) self.reconnects = 0 + self.replaying = False def connect_and_run(self, gateway_url=None): if not gateway_url: @@ -154,6 +159,7 @@ class GatewayClient(LoggingClass): def on_open(self): if self.seq and self.session_id: self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) + self.replaying = True self.send(OPCode.RESUME, { 'token': self.client.config.token, 'session_id': self.session_id, @@ -188,6 +194,8 @@ class GatewayClient(LoggingClass): self.log.info('WS Closed: shutting down') return + self.replaying = False + # Track reconnect attempts self.reconnects += 1 self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 5033b03..f2da195 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -49,6 +49,8 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)): """ Create this GatewayEvent class from data and the client. """ + cls.raw_data = obj + # If this event is wrapping a model, pull its fields if hasattr(cls, '_wraps_model'): alias, model = cls._wraps_model @@ -151,6 +153,7 @@ class GuildCreate(GatewayEvent): and if None, this is a normal guild join event. """ unavailable = Field(bool) + presences = ListField(Presence) @property def created(self): @@ -607,6 +610,14 @@ class MessageReactionAdd(GatewayEvent): user_id = Field(snowflake) emoji = Field(MessageReactionEmoji) + def delete(self): + self.client.api.channels_messages_reactions_delete( + self.channel_id, + self.message_id, + self.emoji.to_string() if self.emoji.id else self.emoji.name, + self.user_id + ) + @property def channel(self): return self.client.state.channels.get(self.channel_id) diff --git a/disco/state.py b/disco/state.py index 689ab62..848665b 100644 --- a/disco/state.py +++ b/disco/state.py @@ -42,10 +42,10 @@ class StateConfig(Config): find they do not need and may be experiencing memory pressure can disable this feature entirely using this attribute. track_messages_size : int - The size of the deque for each channel. Using this you can calculate the - total number of possible :class:`StackMessage` objects kept in memory, - using: `total_mesages_size * total_channels`. This can be tweaked based - on usage to help prevent memory pressure. + The size of the messages deque for each channel. This value can be used + to calculate the total number of possible `StackMessage` objects kept in + memory, simply: `total_messages_size * total_channels`. This value can + be tweaked based on usage and to help prevent memory pressure. sync_guild_members : bool If true, guilds will be automatically synced when they are initially loaded or joined. Generally this setting is OK for smaller bots, however bots in over @@ -60,31 +60,31 @@ class StateConfig(Config): class State(object): """ The State class is used to track global state based on events emitted from - the :class:`GatewayClient`. State tracking is a core component of the Disco - client, providing the mechanism for most of the higher-level utility functions. + the `GatewayClient`. State tracking is a core component of the Disco client, + providing the mechanism for most of the higher-level utility functions. Attributes ---------- EVENTS : list(str) A list of all events the State object binds to - client : :class:`disco.client.Client` + client : `disco.client.Client` The Client instance this state is attached to - config : :class:`StateConfig` + config : `StateConfig` The configuration for this state instance - me : :class:`disco.types.user.User` + me : `User` The currently logged in user - dms : dict(snowflake, :class:`disco.types.channel.Channel`) + dms : dict(snowflake, `Channel`) Mapping of all known DM Channels - guilds : dict(snowflake, :class:`disco.types.guild.Guild`) + guilds : dict(snowflake, `Guild`) Mapping of all known/loaded Guilds - channels : dict(snowflake, :class:`disco.types.channel.Channel`) + channels : dict(snowflake, `Channel`) Weak mapping of all known/loaded Channels - users : dict(snowflake, :class:`disco.types.user.User`) + users : dict(snowflake, `User`) Weak mapping of all known/loaded Users - voice_states : dict(str, :class:`disco.types.voice.VoiceState`) + voice_states : dict(str, `VoiceState`) Weak mapping of all known/active Voice States - messages : Optional[dict(snowflake, :class:`deque`)] - Mapping of channel ids to deques containing :class:`StackMessage` objects + messages : Optional[dict(snowflake, deque)] + Mapping of channel ids to deques containing `StackMessage` objects """ EVENTS = [ 'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove', @@ -184,7 +184,11 @@ class State(object): self.channels.update(event.guild.channels) for member in six.itervalues(event.guild.members): - self.users[member.user.id] = member.user + if member.user.id not in self.users: + self.users[member.user.id] = member.user + + for presence in event.presences: + self.users[presence.user.id].presence = presence for voice_state in six.itervalues(event.guild.voice_states): self.voice_states[voice_state.session_id] = voice_state @@ -282,7 +286,8 @@ class State(object): for member in event.members: member.guild_id = guild.id guild.members[member.id] = member - self.users[member.id] = member.user + if member.id not in self.users: + self.users[member.id] = member.user def on_guild_role_create(self, event): if event.guild_id not in self.guilds: @@ -309,18 +314,33 @@ class State(object): if event.guild_id not in self.guilds: return + for emoji in event.emojis: + emoji.guild_id = event.guild_id + self.guilds[event.guild_id].emojis = HashMap({i.id: i for i in event.emojis}) def on_presence_update(self, event): - if event.user.id in self.users: - self.users[event.user.id].update(event.presence.user) - self.users[event.user.id].presence = event.presence - event.presence.user = self.users[event.user.id] - - if event.guild_id not in self.guilds: + # TODO: this is recursive, we hackfix in model, but its still lame ATM + user = event.presence.user + user.presence = event.presence + + # if we have the user tracked locally, we can just use the presence + # update to update both their presence and the cached user object. + if user.id in self.users: + self.users[user.id].update(user) + else: + # Otherwise this user does not exist in our local cache, so we can + # use this opportunity to add them. They will quickly fall out of + # scope and be deleted if they aren't used below + self.users[user.id] = user + + # Some updates come with a guild_id and roles the user is in, we should + # use this to update the guild member, but only if we have the guild + # cached. + if event.roles is UNSET or event.guild_id not in self.guilds: return - if event.user.id not in self.guilds[event.guild_id].members: + if user.id not in self.guilds[event.guild_id].members: return - self.guilds[event.guild_id].members[event.user.id].user.update(event.user) + self.guilds[event.guild_id].members[user.id].roles = event.roles diff --git a/disco/types/base.py b/disco/types/base.py index c3a061b..d396762 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -6,8 +6,9 @@ import functools from holster.enum import BaseEnumMeta, EnumAttr from datetime import datetime as real_datetime -from disco.util.functional import CachedSlotProperty +from disco.util.chains import Chainable from disco.util.hashmap import HashMap +from disco.util.functional import CachedSlotProperty DATETIME_FORMATS = [ '%Y-%m-%dT%H:%M:%S.%f', @@ -25,6 +26,9 @@ class Unset(object): def __nonzero__(self): return False + def __bool__(self): + return False + UNSET = Unset() @@ -270,15 +274,7 @@ class ModelMeta(type): return super(ModelMeta, mcs).__new__(mcs, name, parents, dct) -class AsyncChainable(object): - __slots__ = [] - - def after(self, delay): - gevent.sleep(delay) - return self - - -class Model(six.with_metaclass(ModelMeta, AsyncChainable)): +class Model(six.with_metaclass(ModelMeta, Chainable)): __slots__ = ['client'] def __init__(self, *args, **kwargs): @@ -294,6 +290,10 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): self.load(obj) self.validate() + def after(self, delay): + gevent.sleep(delay) + return self + def validate(self): pass diff --git a/disco/types/channel.py b/disco/types/channel.py index 311ca5c..39d2f47 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,3 +1,4 @@ +import re import six from six.moves import map @@ -11,6 +12,9 @@ from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.voice.client import VoiceClient +NSFW_RE = re.compile('^nsfw(-|$)') + + ChannelType = Enum( GUILD_TEXT=0, DM=1, @@ -179,6 +183,13 @@ class Channel(SlottedModel, Permissible): """ return self.type in (ChannelType.DM, ChannelType.GROUP_DM) + @property + def is_nsfw(self): + """ + Whether this channel is an NSFW channel. + """ + return self.type == ChannelType.GUILD_TEXT and NSFW_RE.match(self.name) + @property def is_voice(self): """ @@ -244,7 +255,7 @@ class Channel(SlottedModel, Permissible): def create_webhook(self, name=None, avatar=None): return self.client.api.channels_webhooks_create(self.id, name, avatar) - def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None): + def send_message(self, *args, **kwargs): """ Send a message in this channel. @@ -262,7 +273,7 @@ class Channel(SlottedModel, Permissible): :class:`disco.types.message.Message` The created message. """ - return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed) + return self.client.api.channels_messages_create(self.id, *args, **kwargs) def connect(self, *args, **kwargs): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index f93a88f..594d56a 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -9,7 +9,7 @@ from disco.util.functional import cached_property from disco.types.base import ( SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum, datetime ) -from disco.types.user import User, Presence +from disco.types.user import User from disco.types.voice import VoiceState from disco.types.channel import Channel from disco.types.message import Emoji @@ -23,6 +23,17 @@ VerificationLevel = Enum( HIGH=3, ) +ExplicitContentFilterLevel = Enum( + NONE=0, + WITHOUT_ROLES=1, + ALL=2 +) + +DefaultMessageNotificationsLevel = Enum( + ALL_MESSAGES=0, + ONLY_MENTIONS=1, +) + class GuildEmoji(Emoji): """ @@ -51,6 +62,12 @@ class GuildEmoji(Emoji): def __str__(self): return u'<:{}:{}>'.format(self.name, self.id) + def update(self, **kwargs): + return self.client.api.guilds_emojis_modify(self.guild_id, self.id, **kwargs) + + def delete(self): + return self.client.api.guilds_emojis_delete(self.guild_id, self.id) + @property def url(self): return 'https://discordapp.com/api/emojis/{}.png'.format(self.id) @@ -102,7 +119,7 @@ class Role(SlottedModel): @property def mention(self): - return '<@{}>'.format(self.id) + return '<@&{}>'.format(self.id) @cached_property def guild(self): @@ -289,6 +306,8 @@ class Guild(SlottedModel, Permissible): afk_timeout = Field(int) embed_enabled = Field(bool) verification_level = Field(enum(VerificationLevel)) + explicit_content_filter = Field(enum(ExplicitContentFilterLevel)) + default_message_notifications = Field(enum(DefaultMessageNotificationsLevel)) mfa_level = Field(int) features = ListField(str) members = AutoDictField(GuildMember, 'id') @@ -297,7 +316,6 @@ class Guild(SlottedModel, Permissible): emojis = AutoDictField(GuildEmoji, 'id') voice_states = AutoDictField(VoiceState, 'session_id') member_count = Field(int) - presences = ListField(Presence) synced = Field(bool, default=False) @@ -310,6 +328,10 @@ class Guild(SlottedModel, Permissible): self.attach(six.itervalues(self.emojis), {'guild_id': self.id}) self.attach(six.itervalues(self.voice_states), {'guild_id': self.id}) + @cached_property + def owner(self): + return self.members.get(self.owner_id) + def get_permissions(self, member): """ Get the permissions a user has in this guild. @@ -417,3 +439,12 @@ class Guild(SlottedModel, Permissible): def create_channel(self, *args, **kwargs): return self.client.api.guilds_channels_create(self.id, *args, **kwargs) + + def leave(self): + return self.client.api.users_me_guilds_delete(self.id) + + def get_invites(self): + return self.client.api.guilds_invites_list(self.id) + + def get_emojis(self): + return self.client.api.guilds_emojis_list(self.id) diff --git a/disco/types/invite.py b/disco/types/invite.py index 0cc4852..7f22a57 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,4 +1,4 @@ -from disco.types.base import SlottedModel, Field, datetime +from disco.types.base import SlottedModel, Field, datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel @@ -40,7 +40,7 @@ class Invite(SlottedModel): created_at = Field(datetime) @classmethod - def create(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False): + def create_for_channel(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False): return channel.client.api.channels_invites_create( channel.id, max_age=max_age, diff --git a/disco/types/message.py b/disco/types/message.py index 9b33d11..0d1f219 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,5 +1,6 @@ import re import six +import warnings import functools import unicodedata @@ -315,6 +316,12 @@ class Message(SlottedModel): ) def create_reaction(self, emoji): + warnings.warn( + 'Message.create_reaction will be deprecated soon, use Message.add_reaction', + DeprecationWarning) + return self.add_reaction(emoji) + + def add_reaction(self, emoji): if isinstance(emoji, Emoji): emoji = emoji.to_string() self.client.api.channels_messages_reactions_create( diff --git a/disco/types/permissions.py b/disco/types/permissions.py index 6e4d9f3..6157bd9 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -40,6 +40,10 @@ class PermissionValue(object): self.value = value def can(self, *perms): + # Administrator permission overwrites all others + if self.administrator: + return True + for perm in perms: if isinstance(perm, EnumAttr): perm = perm.value diff --git a/disco/types/user.py b/disco/types/user.py index 3192abc..2777c5e 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -45,6 +45,9 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): def mention(self): return '<@{}>'.format(self.id) + def open_dm(self): + return self.client.api.users_me_dms_create(self.id) + def __str__(self): return u'{}#{}'.format(self.username, str(self.discriminator).zfill(4)) diff --git a/disco/types/webhook.py b/disco/types/webhook.py index 4a630d3..b8930c7 100644 --- a/disco/types/webhook.py +++ b/disco/types/webhook.py @@ -1,8 +1,13 @@ +import re + from disco.types.base import SlottedModel, Field, snowflake from disco.types.user import User from disco.util.functional import cached_property +WEBHOOK_URL_RE = re.compile(r'\/api\/webhooks\/(\d+)\/(.[^/]+)') + + class Webhook(SlottedModel): id = Field(snowflake) guild_id = Field(snowflake) @@ -12,6 +17,19 @@ class Webhook(SlottedModel): avatar = Field(str) token = Field(str) + @classmethod + def execute_url(cls, url, **kwargs): + from disco.api.client import APIClient + + results = WEBHOOK_URL_RE.findall(url) + if len(results) != 1: + return Exception('Invalid Webhook URL') + + return cls(id=results[0][0], token=results[0][1]).execute( + client=APIClient(None), + **kwargs + ) + @cached_property def guild(self): return self.client.state.guilds.get(self.guild_id) @@ -32,10 +50,11 @@ class Webhook(SlottedModel): else: return self.client.api.webhooks_modify(self.id, name, avatar) - def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False): + def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False, client=None): # TODO: support file stuff properly + client = client or self.client.api - return self.client.api.webhooks_token_execute(self.id, self.token, { + return client.webhooks_token_execute(self.id, self.token, { 'content': content, 'username': username, 'avatar_url': avatar_url, diff --git a/disco/util/chains.py b/disco/util/chains.py new file mode 100644 index 0000000..5ea261f --- /dev/null +++ b/disco/util/chains.py @@ -0,0 +1,71 @@ +import gevent + +""" +Object.chain -> creates a chain where each action happens after the last + pass_result = False -> whether the result of the last action is passed, or the original + +Object.async_chain -> creates an async chain where each action happens at the same time +""" + + +class Chainable(object): + __slots__ = [] + + def chain(self, pass_result=True): + return Chain(self, pass_result=pass_result, async_=False) + + def async_chain(self): + return Chain(self, pass_result=False, async_=True) + + +class Chain(object): + def __init__(self, obj, pass_result=True, async_=False): + self._obj = obj + self._pass_result = pass_result + self._async = async_ + self._parts = [] + + @property + def obj(self): + if isinstance(self._obj, Chain): + return self._obj._next() + return self._obj + + def __getattr__(self, item): + func = getattr(self.obj, item) + if not func or not callable(func): + return func + + def _wrapped(*args, **kwargs): + inst = gevent.spawn(func, *args, **kwargs) + self._parts.append(inst) + + # If async, just return instantly + if self._async: + return self + + # Otherwise return a chain + return Chain(self) + return _wrapped + + def _next(self): + res = self._parts[0].get() + if self._pass_result: + return res + return self + + def then(self, func, *args, **kwargs): + inst = gevent.spawn(func, *args, **kwargs) + self._parts.append(inst) + if self._async: + return self + return Chain(self) + + def first(self): + return self._obj + + def get(self, timeout=None): + return gevent.wait(self._parts, timeout=timeout) + + def wait(self, timeout=None): + gevent.joinall(self._parts, timeout=None) diff --git a/disco/util/logging.py b/disco/util/logging.py index 68af8a8..61a7b42 100644 --- a/disco/util/logging.py +++ b/disco/util/logging.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +import warnings import logging @@ -9,10 +10,18 @@ LEVEL_OVERRIDES = { LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s' + def setup_logging(**kwargs): kwargs.setdefault('format', LOG_FORMAT) + # Setup warnings module correctly + warnings.simplefilter('always', DeprecationWarning) + logging.captureWarnings(True) + + # Pass through our basic configuration logging.basicConfig(**kwargs) + + # Override some noisey loggers for logger, level in LEVEL_OVERRIDES.items(): logging.getLogger(logger).setLevel(level) diff --git a/disco/util/sanitize.py b/disco/util/sanitize.py new file mode 100644 index 0000000..23a11e4 --- /dev/null +++ b/disco/util/sanitize.py @@ -0,0 +1,32 @@ +import re + + +# Zero width (non-rendering) space that can be used to escape mentions +ZERO_WIDTH_SPACE = u'\u200B' + +# A grave-looking character that can be used to escape codeblocks +MODIFIER_GRAVE_ACCENT = u'\u02CB' + +# Regex which matches all possible mention combinations, this may be over-zealous +# but its better safe than sorry. +MENTION_RE = re.compile('?') + + +def _re_sub_mention(mention): + mention = mention.group(1) + if '#' in mention: + return (u'#' + ZERO_WIDTH_SPACE).join(mention.split('#', 1)) + elif '@' in mention: + return (u'@' + ZERO_WIDTH_SPACE).join(mention.split('@', 1)) + else: + return mention + + +def S(text, escape_mentions=True, escape_codeblocks=False): + if escape_mentions: + text = MENTION_RE.sub(_re_sub_mention, text) + + if escape_codeblocks: + text = text.replace('`', MODIFIER_GRAVE_ACCENT) + + return text diff --git a/disco/util/snowflake.py b/disco/util/snowflake.py index b2f512f..a9aeb15 100644 --- a/disco/util/snowflake.py +++ b/disco/util/snowflake.py @@ -2,6 +2,7 @@ import six from datetime import datetime +UNIX_EPOCH = datetime(1970, 1, 1) DISCORD_EPOCH = 1420070400000 @@ -20,6 +21,14 @@ def to_unix_ms(snowflake): return (int(snowflake) >> 22) + DISCORD_EPOCH +def from_datetime(date): + return from_timestamp((date - UNIX_EPOCH).total_seconds()) + + +def from_timestamp(ts): + return long(ts * 1000.0 - DISCORD_EPOCH) << 22 + + def to_snowflake(i): if isinstance(i, six.integer_types): return i diff --git a/disco/voice/__init__.py b/disco/voice/__init__.py index e69de29..b4a7f6c 100644 --- a/disco/voice/__init__.py +++ b/disco/voice/__init__.py @@ -0,0 +1,3 @@ +from disco.voice.client import * +from disco.voice.player import * +from disco.voice.playable import * diff --git a/disco/voice/client.py b/disco/voice/client.py index 69fe9d5..6f2eae3 100644 --- a/disco/voice/client.py +++ b/disco/voice/client.py @@ -1,8 +1,15 @@ +from __future__ import print_function + import gevent import socket import struct import time +try: + import nacl.secret +except ImportError: + print('WARNING: nacl is not installed, voice support is disabled') + from holster.enum import Enum from holster.emitter import Emitter @@ -22,11 +29,6 @@ VoiceState = Enum( VOICE_CONNECTED=6, ) -# TODO: -# - player implementation -# - encryption -# - cleanup - class VoiceException(Exception): def __init__(self, msg, client): @@ -38,12 +40,40 @@ class UDPVoiceClient(LoggingClass): def __init__(self, vc): super(UDPVoiceClient, self).__init__() self.vc = vc + + # The underlying UDP socket self.conn = None + + # Connection information self.ip = None self.port = None + self.run_task = None self.connected = False + def send_frame(self, frame, sequence=None, timestamp=None): + # Convert the frame to a bytearray + frame = bytearray(frame) + + # First, pack the header (TODO: reuse bytearray?) + header = bytearray(24) + header[0] = 0x80 + header[1] = 0x78 + struct.pack_into('>H', header, 2, sequence or self.vc.sequence) + struct.pack_into('>I', header, 4, timestamp or self.vc.timestamp) + struct.pack_into('>i', header, 8, self.vc.ssrc) + + # Now encrypt the payload with the nonce as a header + raw = self.vc.secret_box.encrypt(bytes(frame), bytes(header)).ciphertext + + # Send the header (sans nonce padding) plus the payload + self.send(header[:12] + raw) + + # Increment our sequence counter + self.vc.sequence += 1 + if self.vc.sequence >= 65535: + self.vc.sequence = 0 + def run(self): while True: self.conn.recvfrom(4096) @@ -54,26 +84,29 @@ class UDPVoiceClient(LoggingClass): def disconnect(self): self.run_task.kill() - def connect(self, host, port, timeout=10): + def connect(self, host, port, timeout=10, addrinfo=None): self.ip = socket.gethostbyname(host) self.port = port self.conn = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # Send discovery packet - packet = bytearray(70) - struct.pack_into('>I', packet, 0, self.vc.ssrc) - self.send(packet) + if addrinfo: + ip, port = addrinfo + else: + # Send discovery packet + packet = bytearray(70) + struct.pack_into('>I', packet, 0, self.vc.ssrc) + self.send(packet) - # Wait for a response - try: - data, addr = gevent.spawn(lambda: self.conn.recvfrom(70)).get(timeout=timeout) - except gevent.Timeout: - return (None, None) + # Wait for a response + try: + data, addr = gevent.spawn(lambda: self.conn.recvfrom(70)).get(timeout=timeout) + except gevent.Timeout: + return (None, None) - # Read IP and port - ip = str(data[4:]).split('\x00', 1)[0] - port = struct.unpack('0]/bestaudio/best'}) + + if self._url: + obj = ydl.extract_info(self._url, download=False, process=False) + if 'entries' in obj: + self._ie_info = obj['entries'][0] + else: + self._ie_info = obj + + self._info = ydl.process_ie_result(self._ie_info, download=False) + return self._info + + @property + def _metadata(self): + return self.info + + @classmethod + def many(cls, url, *args, **kwargs): + import youtube_dl + + ydl = youtube_dl.YoutubeDL({'format': 'webm[abr>0]/bestaudio/best'}) + info = ydl.extract_info(url, download=False, process=False) + + if 'entries' not in info: + yield cls(ie_info=info, *args, **kwargs) + raise StopIteration + + for item in info['entries']: + yield cls(ie_info=item, *args, **kwargs) + + @property + def source(self): + return self.info['url'] + + +class BufferedOpusEncoderPlayable(BasePlayable, OpusEncoder, AbstractOpus): + def __init__(self, source, *args, **kwargs): + self.source = source + self.frames = Queue(kwargs.pop('queue_size', 4096)) + + # Call the AbstractOpus constructor, as we need properties it sets + AbstractOpus.__init__(self, *args, **kwargs) + + # Then call the OpusEncoder constructor, which requires some properties + # that AbstractOpus sets up + OpusEncoder.__init__(self, self.sampling_rate, self.channels) + + # Spawn the encoder loop + gevent.spawn(self._encoder_loop) + + def _encoder_loop(self): + while self.source: + raw = self.source.read(self.frame_size) + if len(raw) < self.frame_size: + break + + self.frames.put(self.encode(raw, self.samples_per_frame)) + gevent.idle() + self.source = None + self.frames.put(None) + + def next_frame(self): + return self.frames.get() + + +class DCADOpusEncoderPlayable(BasePlayable, AbstractOpus, OpusEncoder): + def __init__(self, source, *args, **kwargs): + self.source = source + self.command = kwargs.pop('command', 'dcad') + super(DCADOpusEncoderPlayable, self).__init__(*args, **kwargs) + + self._done = False + self._proc = None + + @property + def proc(self): + if not self._proc: + source = obj = self.source.fileobj() + if not hasattr(obj, 'fileno'): + source = subprocess.PIPE + + self._proc = subprocess.Popen([ + self.command, + '--channels', str(self.channels), + '--rate', str(self.sampling_rate), + '--size', str(self.samples_per_frame), + '--bitrate', '128', + '--fec', + '--packet-loss-percent', '30', + '--input', 'pipe:0', + '--output', 'pipe:1', + ], stdin=source, stdout=subprocess.PIPE) + + def writer(): + while True: + data = obj.read(2048) + if len(data) > 0: + self._proc.stdin.write(data) + if len(data) < 2048: + break + + if source == subprocess.PIPE: + gevent.spawn(writer) + return self._proc + + def next_frame(self): + if self._done: + return None + + header = self.proc.stdout.read(OPUS_HEADER_SIZE) + if len(header) < OPUS_HEADER_SIZE: + self._done = True + return + + size = struct.unpack(' MAX_TIMESTAMP: + self.client.timestamp = 0 + + frame = item.next_frame() + if frame is None: + return + + next_time = start + 0.02 * loops + delay = max(0, 0.02 + (next_time - time.time())) + gevent.sleep(delay) + + def run(self): + self.client.set_speaking(True) + + while self.playing: + self.now_playing = self.queue.get() + + self.events.emit(self.Events.START_PLAY, self.now_playing) + self.play_task = gevent.spawn(self.play, self.now_playing) + self.play_task.join() + self.events.emit(self.Events.STOP_PLAY, self.now_playing) + + if self.client.state == VoiceState.DISCONNECTED: + self.playing = False + self.complete.set() + return + + self.client.set_speaking(False) + self.disconnect() diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 43d28b2..4a1f4dc 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -4,7 +4,5 @@ * [Installation and Setup](INSTALLATION.md) * [Building a Bot](BUILDING_A_BOT.md) * API Docs - * [Client](CLIENT.md) - * Types - * [Message](types/MESSAGE.md) + * [Client](api/disco_client.md) diff --git a/docs/types/MESSAGE.md b/docs/types/MESSAGE.md deleted file mode 100644 index 291ec5d..0000000 --- a/docs/types/MESSAGE.md +++ /dev/null @@ -1,3 +0,0 @@ -# Message - -TODO diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 57faae0..a48b21b 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -50,7 +50,7 @@ class BasicPlugin(Plugin): if not users: event.msg.reply("Couldn't find user for your query: `{}`".format(query)) elif len(users) > 1: - event.msg.reply('I found too many userse ({}) for your query: `{}`'.format(len(users), query)) + event.msg.reply('I found too many users ({}) for your query: `{}`'.format(len(users), query)) else: user = users[0] parts = [] diff --git a/examples/music.py b/examples/music.py new file mode 100644 index 0000000..2c376c3 --- /dev/null +++ b/examples/music.py @@ -0,0 +1,52 @@ +from disco.bot import Plugin +from disco.bot.command import CommandError +from disco.voice.player import Player +from disco.voice.playable import YoutubeDLInput, BufferedOpusEncoderPlayable +from disco.voice.client import VoiceException + + +class MusicPlugin(Plugin): + def load(self, ctx): + super(MusicPlugin, self).load(ctx) + self.guilds = {} + + @Plugin.command('join') + def on_join(self, event): + if event.guild.id in self.guilds: + return event.msg.reply("I'm already playing music here.") + + state = event.guild.get_member(event.author).get_voice_state() + if not state: + return event.msg.reply('You must be connected to voice to use that command.') + + try: + client = state.channel.connect() + except VoiceException as e: + return event.msg.reply('Failed to connect to voice: `{}`'.format(e)) + + self.guilds[event.guild.id] = Player(client) + self.guilds[event.guild.id].complete.wait() + del self.guilds[event.guild.id] + + def get_player(self, guild_id): + if guild_id not in self.guilds: + raise CommandError("I'm not currently playing music here.") + return self.guilds.get(guild_id) + + @Plugin.command('leave') + def on_leave(self, event): + player = self.get_player(event.guild.id) + player.disconnect() + + @Plugin.command('play', '') + def on_play(self, event, url): + item = YoutubeDLInput(url).pipe(BufferedOpusEncoderPlayable) + self.get_player(event.guild.id).queue.put(item) + + @Plugin.command('pause') + def on_pause(self, event): + self.get_player(event.guild.id).pause() + + @Plugin.command('resume') + def on_resume(self, event): + self.get_player(event.guild.id).resume() diff --git a/examples/storage.py b/examples/storage.py new file mode 100644 index 0000000..c8e5ce3 --- /dev/null +++ b/examples/storage.py @@ -0,0 +1,33 @@ +from disco.bot import Plugin + + +class BasicPlugin(Plugin): + def load(self, ctx): + super(BasicPlugin, self).load(ctx) + self.tags = self.storage.guild('tags') + + @Plugin.command('add', ' ', group='tags') + def on_tags_add(self, event, name, value): + if name in self.tags: + return event.msg.reply('That tag already exists!') + + self.tags[name] = value + return event.msg.reply(u':ok_hand: created the tag {}'.format(name), sanitize=True) + + @Plugin.command('get', '', group='tags') + def on_tags_get(self, event, name): + if name not in self.tags: + return event.msg.reply('That tag does not exist!') + + return event.msg.reply(self.tags[name], sanitize=True) + + @Plugin.command('delete', '', group='tags', aliases=['del', 'rmv', 'remove']) + def on_tags_delete(self, event, name): + if name not in self.tags: + return event.msg.reply('That tag does not exist!') + + del self.tags[name] + + return event.msg.reply(u':ok_hand: I deleted the {} tag for you'.format( + name + ), sanitize=True) diff --git a/requirements.txt b/requirements.txt index 3b0894b..a0de4f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -gevent==1.1.2 -holster==1.0.11 +gevent==1.2.1 +holster==1.0.14 inflection==0.3.1 -requests==2.11.1 +requests==2.13.0 six==1.10.0 -websocket-client==0.37.0 +websocket-client==0.40.0 diff --git a/setup.py b/setup.py index 7750612..2b2a0f1 100644 --- a/setup.py +++ b/setup.py @@ -2,12 +2,25 @@ from setuptools import setup, find_packages from disco import VERSION + +def run_tests(): + import unittest + test_loader = unittest.TestLoader() + test_suite = test_loader.discover('tests', pattern='test_*.py') + return test_suite + + with open('requirements.txt') as f: requirements = f.readlines() with open('README.md') as f: readme = f.read() +extras_require = { + 'voice': ['pynacl==1.1.2'], + 'performance': ['erlpack==0.3.2'], +} + setup( name='disco-py', author='b1nzy', @@ -19,6 +32,8 @@ setup( long_description=readme, include_package_data=True, install_requires=requirements, + extras_require=extras_require, + test_suite='setup.run_tests', classifiers=[ 'Development Status :: 4 - Beta', 'License :: OSI Approved :: MIT License', diff --git a/tests/test_bot.py b/tests/test_bot.py new file mode 100644 index 0000000..566de0a --- /dev/null +++ b/tests/test_bot.py @@ -0,0 +1,47 @@ +from unittest import TestCase + +from disco.client import ClientConfig, Client +from disco.bot.bot import Bot +from disco.bot.command import Command + + +class MockBot(Bot): + @property + def commands(self): + return getattr(self, '_commands', []) + + +class TestBot(TestCase): + def setUp(self): + self.client = Client(ClientConfig( + {'config': 'TEST_TOKEN'} + )) + self.bot = MockBot(self.client) + + def test_command_abbreviation(self): + groups = ['config', 'copy', 'copez', 'copypasta'] + result = self.bot.compute_group_abbrev(groups) + self.assertDictEqual(result, { + 'config': 'con', + 'copypasta': 'copy', + 'copez': 'cope', + }) + + def test_command_abbreivation_conflicting(self): + groups = ['cat', 'cap', 'caz', 'cas'] + result = self.bot.compute_group_abbrev(groups) + self.assertDictEqual(result, {}) + + def test_many_commands(self): + self.bot._commands = [ + Command(None, None, 'test{}'.format(i), '') + for i in range(1000) + ] + + self.bot.compute_command_matches_re() + match = self.bot.command_matches_re.match('test5 123') + self.assertNotEqual(match, None) + + match = self.bot._commands[0].compiled_regex.match('test0 123 456') + self.assertEqual(match.group(1).strip(), 'test0') + self.assertEqual(match.group(2).strip(), '123 456') diff --git a/tests/test_channel.py b/tests/test_channel.py new file mode 100644 index 0000000..56d5227 --- /dev/null +++ b/tests/test_channel.py @@ -0,0 +1,21 @@ +from unittest import TestCase + +from disco.types.channel import Channel, ChannelType + + +class TestChannel(TestCase): + def test_nsfw_channel(self): + channel = Channel( + name='nsfw-testing', + type=ChannelType.GUILD_TEXT) + self.assertTrue(channel.is_nsfw) + + channel = Channel( + name='nsfw-testing', + type=ChannelType.GUILD_VOICE) + self.assertFalse(channel.is_nsfw) + + channel = Channel( + name='nsfw_testing', + type=ChannelType.GUILD_TEXT) + self.assertFalse(channel.is_nsfw) diff --git a/tests/test_embeds.py b/tests/test_embeds.py new file mode 100644 index 0000000..ef60361 --- /dev/null +++ b/tests/test_embeds.py @@ -0,0 +1,32 @@ +from unittest import TestCase + +from datetime import datetime +from disco.types.message import MessageEmbed + + +class TestEmbeds(TestCase): + def test_empty_embed(self): + embed = MessageEmbed() + self.assertDictEqual( + embed.to_dict(), + { + 'image': {}, + 'author': {}, + 'video': {}, + 'thumbnail': {}, + 'footer': {}, + 'fields': [], + 'type': 'rich', + }) + + def test_embed(self): + embed = MessageEmbed( + title='Test Title', + description='Test Description', + url='https://test.url/' + ) + obj = embed.to_dict() + self.assertEqual(obj['title'], 'Test Title') + self.assertEqual(obj['description'], 'Test Description') + self.assertEqual(obj['url'], 'https://test.url/') + diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000..8257c0a --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,42 @@ +""" +This module tests that all of disco can be imported, mostly to help reduce issues +with untested code that will not even parse/run on Py2/3 +""" +from disco.api.client import * +from disco.api.http import * +from disco.api.ratelimit import * +from disco.bot.bot import * +from disco.bot.command import * +from disco.bot.parser import * +from disco.bot.plugin import * +from disco.bot.storage import * +from disco.gateway.client import * +from disco.gateway.events import * +from disco.gateway.ipc import * +from disco.gateway.packets import * +# Not imported, GIPC is required but not provided by default +# from disco.gateway.sharder import * +from disco.types.base import * +from disco.types.channel import * +from disco.types.guild import * +from disco.types.invite import * +from disco.types.message import * +from disco.types.permissions import * +from disco.types.user import * +from disco.types.voice import * +from disco.types.webhook import * +from disco.util.backdoor import * +from disco.util.config import * +from disco.util.functional import * +from disco.util.hashmap import * +from disco.util.limiter import * +from disco.util.logging import * +from disco.util.serializer import * +from disco.util.snowflake import * +from disco.util.token import * +from disco.util.websocket import * +from disco.voice.client import * +from disco.voice.opus import * +from disco.voice.packets import * +from disco.voice.playable import * +from disco.voice.player import *