From b44e0e1b9bdef5f96a7946b7cd79f598f4c091db Mon Sep 17 00:00:00 2001 From: Ethan Date: Fri, 14 Oct 2016 15:16:28 -0400 Subject: [PATCH 01/91] Fixed Route Typo (#8) --- disco/api/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/api/client.py b/disco/api/client.py index ac83dea..3945e76 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -236,7 +236,7 @@ class APIClient(LoggingClass): return Webhook.create(self.client, r.json()) def webhooks_token_delete(self, webhook, token): - self.http(Routes.WEBHOOKS_TOKEN_DLEETE, dict(webhook=webhook, token=token)) + self.http(Routes.WEBHOOKS_TOKEN_DELETE, dict(webhook=webhook, token=token)) def webhooks_token_execute(self, webhook, token, data, wait=False): obj = self.http( From ce29836e84f711288008e766d63e426f7ea3ba84 Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 17 Oct 2016 19:36:03 -0500 Subject: [PATCH 02/91] First stab at IPC based auto sharding --- disco/api/client.py | 11 ++++++-- disco/api/http.py | 2 ++ disco/bot/bot.py | 4 +-- disco/cli.py | 6 +++++ disco/client.py | 2 +- disco/gateway/client.py | 24 +++++++++++------ disco/gateway/ipc/__init__.py | 0 disco/gateway/ipc/gipc.py | 50 +++++++++++++++++++++++++++++++++++ disco/gateway/sharder.py | 31 ++++++++++++++++++++++ 9 files changed, 117 insertions(+), 13 deletions(-) create mode 100644 disco/gateway/ipc/__init__.py create mode 100644 disco/gateway/ipc/gipc.py create mode 100644 disco/gateway/sharder.py diff --git a/disco/api/client.py b/disco/api/client.py index 3945e76..b9cc813 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -32,9 +32,13 @@ class APIClient(LoggingClass): self.client = client self.http = HTTPClient(self.client.config.token) - def gateway(self, version, encoding): + def gateway_get(self): data = self.http(Routes.GATEWAY_GET).json() - return data['url'] + '?v={}&encoding={}'.format(version, encoding) + return data + + def gateway_bot_get(self): + data = self.http(Routes.GATEWAY_BOT_GET).json() + return data def channels_get(self, channel): r = self.http(Routes.CHANNELS_GET, dict(channel=channel)) @@ -48,6 +52,9 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel)) return Channel.create(self.client, r.json()) + def channels_typing(self, channel): + self.http(Routes.CHANNELS_TYPING, dict(channel=channel)) + def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50): r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional( around=around, diff --git a/disco/api/http.py b/disco/api/http.py index c1930bd..4757da4 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -25,12 +25,14 @@ class Routes(object): """ # Gateway GATEWAY_GET = (HTTPMethod.GET, '/gateway') + GATEWAY_BOT_GET = (HTTPMethod.GET, '/gateway/bot') # Channels CHANNELS = '/channels/{channel}' CHANNELS_GET = (HTTPMethod.GET, CHANNELS) CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS) CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS) + CHANNELS_TYPING = (HTTPMethod.POST, CHANNELS + '/typing') CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages') CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages') diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 6d5f250..99ac0b5 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -430,8 +430,8 @@ class Bot(object): def load_plugin_config(self, cls): name = cls.__name__.lower() - if name.startswith('plugin'): - name = name[6:] + if name.endswith('plugin'): + name = name[:-6] path = os.path.join( self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format diff --git a/disco/cli.py b/disco/cli.py index 951fd96..27cc220 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -18,6 +18,7 @@ parser.add_argument('--config', help='Configuration file', default='config.yaml' parser.add_argument('--token', help='Bot Authentication Token', default=None) parser.add_argument('--shard-count', help='Total number of shards', default=None) parser.add_argument('--shard-id', help='Current shard number/id', default=None) +parser.add_argument('--shard-auto', help='Automatically run all shards', action='store_true', default=False) parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=None) parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default=None) parser.add_argument('--encoder', help='encoder for gateway data', default=None) @@ -41,6 +42,7 @@ def disco_main(run=False): from disco.client import Client, ClientConfig from disco.bot import Bot, BotConfig + from disco.gateway.sharder import AutoSharder from disco.util.token import is_valid_token if os.path.exists(args.config): @@ -56,6 +58,10 @@ def disco_main(run=False): print('Invalid token passed') return + if args.shard_auto: + AutoSharder(config).run() + return + client = Client(config) bot = None diff --git a/disco/client.py b/disco/client.py index 54dacd5..496aa7a 100644 --- a/disco/client.py +++ b/disco/client.py @@ -37,7 +37,7 @@ class ClientConfig(LoggingClass, Config): shard_id = 0 shard_count = 1 - manhole_enable = True + manhole_enable = False manhole_bind = ('127.0.0.1', 8484) encoder = 'json' diff --git a/disco/gateway/client.py b/disco/gateway/client.py index b3f8012..61f7dfd 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -17,7 +17,7 @@ class GatewayClient(LoggingClass): GATEWAY_VERSION = 6 MAX_RECONNECTS = 5 - def __init__(self, client, encoder='json'): + def __init__(self, client, encoder='json', ipc=None): super(GatewayClient, self).__init__() self.client = client self.encoder = ENCODERS[encoder] @@ -25,6 +25,11 @@ class GatewayClient(LoggingClass): self.events = client.events self.packets = client.packets + # IPC for shards + if ipc: + self.shards = ipc.get_shards() + self.ipc = ipc + # Its actually 60, 120 but lets give ourselves a buffer self.limiter = SimpleLimiter(60, 130) @@ -98,14 +103,17 @@ class GatewayClient(LoggingClass): self.session_id = ready.session_id self.reconnects = 0 - def connect_and_run(self): - if not self._cached_gateway_url: - self._cached_gateway_url = self.client.api.gateway( - version=self.GATEWAY_VERSION, - encoding=self.encoder.TYPE) + def connect_and_run(self, gateway_url=None): + if not gateway_url: + if not self._cached_gateway_url: + self._cached_gateway_url = self.client.api.gateway_get()['url'] + + gateway_url = self._cached_gateway_url + + gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE) - self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) - self.ws = Websocket(self._cached_gateway_url) + self.log.info('Opening websocket connection to URL `%s`', gateway_url) + self.ws = Websocket(gateway_url) self.ws.emitter.on('on_open', self.on_open) self.ws.emitter.on('on_error', self.on_error) self.ws.emitter.on('on_close', self.on_close) diff --git a/disco/gateway/ipc/__init__.py b/disco/gateway/ipc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/disco/gateway/ipc/gipc.py b/disco/gateway/ipc/gipc.py new file mode 100644 index 0000000..b9468d0 --- /dev/null +++ b/disco/gateway/ipc/gipc.py @@ -0,0 +1,50 @@ +import random +import gipc +import gevent +import string +import weakref + + +def get_random_str(size): + return ''.join([random.choice(string.ascii_printable) for _ in range(size)]) + + +class GIPCProxy(object): + def __init__(self, pipe): + self.pipe = pipe + self.results = weakref.WeakValueDictionary() + gevent.spawn(self.read_loop) + + def read_loop(self): + while True: + nonce, data = self.pipe.get() + if nonce in self.results: + self.results[nonce].set(data) + + def __getattr__(self, name): + def wrapper(*args, **kwargs): + nonce = get_random_str() + self.results[nonce] = gevent.event.AsyncResult() + self.pipe.put(nonce, name, args, kwargs) + return self.results[nonce] + return wrapper + + +class GIPCObject(object): + def __init__(self, inst, pipe): + self.inst = inst + self.pipe = pipe + gevent.spawn(self.read_loop) + + def read_loop(self): + while True: + nonce, func, args, kwargs = self.pipe.get() + func = getattr(self.inst, func) + self.pipe.put((nonce, func(*args, **kwargs))) + +class IPC(object): + def __init__(self, sharder): + self.sharder = sharder + + def get_shards(self): + return {} diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py new file mode 100644 index 0000000..fb89875 --- /dev/null +++ b/disco/gateway/sharder.py @@ -0,0 +1,31 @@ +import gipc + +from disco.client import Client +from disco.bot import Bot, BotConfig +from disco.api.client import APIClient +from disco.gateway.ipc.gipc import GIPCObject, GIPCProxy + + +def run_shard(config, id, pipe): + config.shard_id = id + client = Client(config) + bot = Bot(client, BotConfig(config.bot)) + GIPCObject(bot, pipe) + bot.run_forever() + + +class AutoSharder(object): + def __init__(self, config): + self.config = config + self.client = APIClient(config.token) + self.shards = {} + self.config.shard_count = self.client.gateway_bot_get()['shards'] + + def run(self): + for shard in range(self.shard_count): + self.start_shard(shard) + + def start_shard(self, id): + cpipe, ppipe = gipc.pipe(duplex=True) + gipc.start_process(run_shard, (self.config, id, cpipe)) + self.shards[id] = GIPCProxy(ppipe) From aee7fa13e100d7ee2515d2152756bb0ca2405cce Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 17 Oct 2016 21:46:17 -0500 Subject: [PATCH 03/91] Refine autosharding/IPC --- disco/api/client.py | 4 +- disco/cli.py | 4 +- disco/client.py | 2 +- disco/gateway/ipc/gipc.py | 100 +++++++++++++++++++++++++++----------- disco/gateway/sharder.py | 48 ++++++++++++++++-- disco/types/message.py | 48 ++++++++++++++++++ 6 files changed, 168 insertions(+), 38 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index b9cc813..6705f16 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -26,11 +26,11 @@ class APIClient(LoggingClass): An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits the models with the returned data. """ - def __init__(self, client): + def __init__(self, token, client=None): super(APIClient, self).__init__() self.client = client - self.http = HTTPClient(self.client.config.token) + self.http = HTTPClient(token) def gateway_get(self): data = self.http(Routes.GATEWAY_GET).json() diff --git a/disco/cli.py b/disco/cli.py index 27cc220..f2ef504 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -25,8 +25,6 @@ parser.add_argument('--encoder', help='encoder for gateway data', default=None) parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False) parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[]) -logging.basicConfig(level=logging.INFO) - def disco_main(run=False): """ @@ -62,6 +60,8 @@ def disco_main(run=False): AutoSharder(config).run() return + logging.basicConfig(level=logging.INFO) + client = Client(config) bot = None diff --git a/disco/client.py b/disco/client.py index 496aa7a..0c15c50 100644 --- a/disco/client.py +++ b/disco/client.py @@ -82,7 +82,7 @@ class Client(object): self.events = Emitter(gevent.spawn) self.packets = Emitter(gevent.spawn) - self.api = APIClient(self) + self.api = APIClient(self.config.token, self) self.gw = GatewayClient(self, self.config.encoder) self.state = State(self, StateConfig(self.config.get('state', {}))) diff --git a/disco/gateway/ipc/gipc.py b/disco/gateway/ipc/gipc.py index b9468d0..4fdd49e 100644 --- a/disco/gateway/ipc/gipc.py +++ b/disco/gateway/ipc/gipc.py @@ -1,50 +1,92 @@ import random -import gipc import gevent import string import weakref +import marshal +import types + +from holster.enum import Enum + +from disco.util.logging import LoggingClass def get_random_str(size): - return ''.join([random.choice(string.ascii_printable) for _ in range(size)]) + return ''.join([random.choice(string.printable) for _ in range(size)]) + + +IPCMessageType = Enum( + 'CALL_FUNC', + 'GET_ATTR', + 'EXECUTE', + 'RESPONSE', +) -class GIPCProxy(object): - def __init__(self, pipe): +class GIPCProxy(LoggingClass): + def __init__(self, obj, pipe): + super(GIPCProxy, self).__init__() + self.obj = obj self.pipe = pipe self.results = weakref.WeakValueDictionary() gevent.spawn(self.read_loop) - def read_loop(self): - while True: - nonce, data = self.pipe.get() - if nonce in self.results: - self.results[nonce].set(data) + def resolve(self, parts): + base = self.obj + for part in parts: + base = getattr(base, part) - def __getattr__(self, name): - def wrapper(*args, **kwargs): - nonce = get_random_str() - self.results[nonce] = gevent.event.AsyncResult() - self.pipe.put(nonce, name, args, kwargs) - return self.results[nonce] - return wrapper + return base + def send(self, typ, data): + self.pipe.put((typ.value, data)) -class GIPCObject(object): - def __init__(self, inst, pipe): - self.inst = inst - self.pipe = pipe - gevent.spawn(self.read_loop) + def handle(self, mtype, data): + if mtype == IPCMessageType.CALL_FUNC: + nonce, func, args, kwargs = data + res = self.resolve(func)(*args, **kwargs) + self.send(IPCMessageType.RESPONSE, (nonce, res)) + elif mtype == IPCMessageType.GET_ATTR: + nonce, path = data + self.send(IPCMessageType.RESPONSE, (nonce, self.resolve(path))) + elif mtype == IPCMessageType.EXECUTE: + nonce, raw = data + func = types.FunctionType(marshal.loads(raw), globals(), nonce) + try: + result = func(self.obj) + except Exception as e: + self.log.exception('Failed to EXECUTE: ') + result = None + + self.send(IPCMessageType.RESPONSE, (nonce, result)) + elif mtype == IPCMessageType.RESPONSE: + nonce, res = data + if nonce in self.results: + self.results[nonce].set(res) def read_loop(self): while True: - nonce, func, args, kwargs = self.pipe.get() - func = getattr(self.inst, func) - self.pipe.put((nonce, func(*args, **kwargs))) + mtype, data = self.pipe.get() + + try: + self.handle(mtype, data) + except: + self.log.exception('Error in GIPCProxy:') + + def execute(self, func): + nonce = get_random_str(32) + raw = marshal.dumps(func.func_code) + self.results[nonce] = result = gevent.event.AsyncResult() + self.pipe.put((IPCMessageType.EXECUTE.value, (nonce, raw))) + return result -class IPC(object): - def __init__(self, sharder): - self.sharder = sharder + def get(self, path): + nonce = get_random_str(32) + self.results[nonce] = result = gevent.event.AsyncResult() + self.pipe.put((IPCMessageType.GET_ATTR.value, (nonce, path))) + return result - def get_shards(self): - return {} + def call(self, path, *args, **kwargs): + nonce = get_random_str(32) + self.results[nonce] = result = gevent.event.AsyncResult() + self.pipe.put((IPCMessageType.CALL_FUNC.value, (nonce, path, args, kwargs))) + return result diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index fb89875..987ca73 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -1,16 +1,46 @@ +from __future__ import absolute_import + import gipc +import gevent +import types +import marshal from disco.client import Client from disco.bot import Bot, BotConfig from disco.api.client import APIClient -from disco.gateway.ipc.gipc import GIPCObject, GIPCProxy +from disco.gateway.ipc.gipc import GIPCProxy + + +def run_on(id, proxy): + def f(func): + return proxy.call(('run_on', ), id, marshal.dumps(func.func_code)) + return f + + +def run_self(bot): + def f(func): + result = gevent.event.AsyncResult() + result.set(func(bot)) + return result + return f def run_shard(config, id, pipe): + import logging + logging.basicConfig( + level=logging.INFO, + format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) + ) + config.shard_id = id client = Client(config) bot = Bot(client, BotConfig(config.bot)) - GIPCObject(bot, pipe) + bot.sharder = GIPCProxy(bot, pipe) + bot.shards = { + i: run_on(i, bot.sharder) for i in range(config.shard_count) + if i != id + } + bot.shards[id] = run_self(bot) bot.run_forever() @@ -20,12 +50,22 @@ class AutoSharder(object): self.client = APIClient(config.token) self.shards = {} self.config.shard_count = self.client.gateway_bot_get()['shards'] + self.config.shard_count = 10 + self.test = 1 + + def run_on(self, id, funccode): + func = types.FunctionType(marshal.loads(funccode), globals(), '_run_on_temp') + return self.shards[id].execute(func).wait(timeout=15) def run(self): - for shard in range(self.shard_count): + for shard in range(self.config.shard_count): + if self.config.manhole_enable and shard != 0: + self.config.manhole_enable = False + self.start_shard(shard) + gevent.sleep(6) def start_shard(self, id): cpipe, ppipe = gipc.pipe(duplex=True) gipc.start_process(run_shard, (self.config, id, cpipe)) - self.shards[id] = GIPCProxy(ppipe) + self.shards[id] = GIPCProxy(self, ppipe) diff --git a/disco/types/message.py b/disco/types/message.py index 3e9959f..15f751d 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -300,3 +300,51 @@ class Message(SlottedModel): return user_replace(self.mentions.get(id)) return re.sub('<@!?([0-9]+)>', replace, self.content) + + +class MessageTable(object): + def __init__(self, sep=' | ', codeblock=True, header_break=True): + self.header = [] + self.entries = [] + self.size_index = {} + self.sep = sep + self.codeblock = codeblock + self.header_break = header_break + + def recalculate_size_index(self, cols): + for idx, col in enumerate(cols): + if idx not in self.size_index or len(col) > self.size_index[idx]: + self.size_index[idx] = len(col) + + def set_header(self, *args): + self.header = args + self.recalculate_size_index(args) + + def add(self, *args): + args = list(map(str, args)) + self.entries.append(args) + self.recalculate_size_index(args) + + def compile_one(self, cols): + data = self.sep.lstrip() + + for idx, col in enumerate(cols): + padding = ' ' * ((self.size_index[idx] - len(col))) + data += col + padding + self.sep + + return data.rstrip() + + def compile(self): + data = [] + data.append(self.compile_one(self.header)) + + if self.header_break: + data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1)) + + for row in self.entries: + data.append(self.compile_one(row)) + + if self.codeblock: + return '```' + '\n'.join(data) + '```' + + return '\n'.join(data) From 77970c9bef408b536b9a36a6f7847b4ed73a6721 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 18 Oct 2016 17:32:03 -0500 Subject: [PATCH 04/91] Refactor plugin management a bit, etc fixes/improvements --- disco/bot/bot.py | 14 ++++++---- disco/bot/plugin.py | 58 ++++++++++++++++++++++++--------------- disco/cli.py | 2 ++ disco/gateway/ipc/gipc.py | 9 +++--- disco/gateway/sharder.py | 25 +++++++++++------ disco/state.py | 3 ++ 6 files changed, 69 insertions(+), 42 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 99ac0b5..70abe0b 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -262,7 +262,7 @@ class Bot(object): content = content.replace('@everyone', '', 1) else: for role in mention_roles: - content = content.replace(role.mention, '', 1) + content = content.replace('<@{}>'.format(role), '', 1) content = content.lstrip() @@ -356,7 +356,7 @@ class Bot(object): self.last_message_cache[msg.channel_id] = (msg, triggered) - def add_plugin(self, cls, config=None): + def add_plugin(self, cls, config=None, ctx=None): """ Adds and loads a plugin, based on its class. @@ -377,7 +377,7 @@ class Bot(object): config = self.load_plugin_config(cls) self.plugins[cls.__name__] = cls(self, config) - self.plugins[cls.__name__].load() + self.plugins[cls.__name__].load(ctx or {}) self.recompute() def rmv_plugin(self, cls): @@ -392,9 +392,11 @@ class Bot(object): if cls.__name__ not in self.plugins: raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__)) - self.plugins[cls.__name__].unload() + ctx = {} + self.plugins[cls.__name__].unload(ctx) del self.plugins[cls.__name__] self.recompute() + return ctx def reload_plugin(self, cls): """ @@ -402,9 +404,9 @@ class Bot(object): """ config = self.plugins[cls.__name__].config - self.rmv_plugin(cls) + ctx = self.rmv_plugin(cls) module = reload_module(inspect.getmodule(cls)) - self.add_plugin(getattr(module, cls.__name__), config) + self.add_plugin(getattr(module, cls.__name__), config, ctx) def run_forever(self): """ diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 5ee6803..8451b82 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -1,4 +1,5 @@ import six +import types import gevent import inspect import weakref @@ -18,8 +19,8 @@ class PluginDeco(object): Prio = Priority # TODO: dont smash class methods - @staticmethod - def add_meta_deco(meta): + @classmethod + def add_meta_deco(cls, meta): def deco(f): if not hasattr(f, 'meta'): f.meta = [] @@ -153,6 +154,20 @@ class Plugin(LoggingClass, PluginDeco): self.storage = bot.storage self.config = config + # This is an array of all meta functions we sniff at init + self.meta_funcs = [] + + for name, member in inspect.getmembers(self, predicate=inspect.ismethod): + if hasattr(member, 'meta'): + self.meta_funcs.append(member) + + # Unsmash local functions + if hasattr(Plugin, name): + method = types.MethodType(getattr(Plugin, name), self, self.__class__) + setattr(self, name, method) + + self.bind_all() + @property def name(self): return self.__class__.__name__ @@ -166,23 +181,21 @@ class Plugin(LoggingClass, PluginDeco): self._pre = {'command': [], 'listener': []} self._post = {'command': [], 'listener': []} - # TODO: when handling events/commands we need to track the greenlet in - # the greenlets set so we can termiante long running commands/listeners - # on reload. - - for name, member in inspect.getmembers(self, predicate=inspect.ismethod): - if hasattr(member, 'meta'): - for meta in member.meta: - if meta['type'] == 'listener': - self.register_listener(member, meta['what'], meta['desc'], meta['priority']) - elif meta['type'] == 'command': - meta['kwargs']['update'] = True - self.register_command(member, *meta['args'], **meta['kwargs']) - elif meta['type'] == 'schedule': - self.register_schedule(member, *meta['args'], **meta['kwargs']) - elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): - when, typ = meta['type'].split('_', 1) - self.register_trigger(typ, when, member) + for member in self.meta_funcs: + for meta in member.meta: + self.bind_meta(member, meta) + + def bind_meta(self, member, meta): + if meta['type'] == 'listener': + self.register_listener(member, meta['what'], meta['desc'], meta['priority']) + elif meta['type'] == 'command': + meta['kwargs']['update'] = True + self.register_command(member, *meta['args'], **meta['kwargs']) + elif meta['type'] == 'schedule': + self.register_schedule(member, *meta['args'], **meta['kwargs']) + elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): + when, typ = meta['type'].split('_', 1) + self.register_trigger(typ, when, member) def spawn(self, method, *args, **kwargs): obj = gevent.spawn(method, *args, **kwargs) @@ -208,6 +221,7 @@ class Plugin(LoggingClass, PluginDeco): getattr(self, '_' + when)[typ].append(func) def _dispatch(self, typ, func, event, *args, **kwargs): + self.greenlets.add(gevent.getcurrent()) self.ctx['plugin'] = self if hasattr(event, 'guild'): @@ -302,13 +316,13 @@ class Plugin(LoggingClass, PluginDeco): self.schedules[func.__name__] = self.spawn(repeat) - def load(self): + def load(self, ctx): """ Called when the plugin is loaded """ - self.bind_all() + pass - def unload(self): + def unload(self, ctx): """ Called when the plugin is unloaded """ diff --git a/disco/cli.py b/disco/cli.py index f2ef504..487de95 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -42,6 +42,7 @@ def disco_main(run=False): from disco.bot import Bot, BotConfig from disco.gateway.sharder import AutoSharder from disco.util.token import is_valid_token + from holster.log import set_logging_levels if os.path.exists(args.config): config = ClientConfig.from_file(args.config) @@ -61,6 +62,7 @@ def disco_main(run=False): return logging.basicConfig(level=logging.INFO) + set_logging_levels() client = Client(config) diff --git a/disco/gateway/ipc/gipc.py b/disco/gateway/ipc/gipc.py index 4fdd49e..e23206b 100644 --- a/disco/gateway/ipc/gipc.py +++ b/disco/gateway/ipc/gipc.py @@ -2,8 +2,7 @@ import random import gevent import string import weakref -import marshal -import types +import dill from holster.enum import Enum @@ -50,10 +49,10 @@ class GIPCProxy(LoggingClass): self.send(IPCMessageType.RESPONSE, (nonce, self.resolve(path))) elif mtype == IPCMessageType.EXECUTE: nonce, raw = data - func = types.FunctionType(marshal.loads(raw), globals(), nonce) + func = dill.loads(raw) try: result = func(self.obj) - except Exception as e: + except Exception: self.log.exception('Failed to EXECUTE: ') result = None @@ -74,7 +73,7 @@ class GIPCProxy(LoggingClass): def execute(self, func): nonce = get_random_str(32) - raw = marshal.dumps(func.func_code) + raw = dill.dumps(func) self.results[nonce] = result = gevent.event.AsyncResult() self.pipe.put((IPCMessageType.EXECUTE.value, (nonce, raw))) return result diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index 987ca73..cb387a4 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -2,8 +2,10 @@ from __future__ import absolute_import import gipc import gevent -import types -import marshal +import logging +import dill + +from holster.log import set_logging_levels from disco.client import Client from disco.bot import Bot, BotConfig @@ -13,7 +15,7 @@ from disco.gateway.ipc.gipc import GIPCProxy def run_on(id, proxy): def f(func): - return proxy.call(('run_on', ), id, marshal.dumps(func.func_code)) + return proxy.call(('run_on', ), id, dill.dumps(func)) return f @@ -26,11 +28,11 @@ def run_self(bot): def run_shard(config, id, pipe): - import logging logging.basicConfig( level=logging.INFO, format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) ) + set_logging_levels() config.shard_id = id client = Client(config) @@ -50,11 +52,11 @@ class AutoSharder(object): self.client = APIClient(config.token) self.shards = {} self.config.shard_count = self.client.gateway_bot_get()['shards'] - self.config.shard_count = 10 - self.test = 1 + if self.config.shard_count > 1: + self.config.shard_count = 10 - def run_on(self, id, funccode): - func = types.FunctionType(marshal.loads(funccode), globals(), '_run_on_temp') + def run_on(self, id, raw): + func = dill.loads(raw) return self.shards[id].execute(func).wait(timeout=15) def run(self): @@ -65,7 +67,12 @@ class AutoSharder(object): self.start_shard(shard) gevent.sleep(6) + logging.basicConfig( + level=logging.INFO, + format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) + ) + def start_shard(self, id): - cpipe, ppipe = gipc.pipe(duplex=True) + cpipe, ppipe = gipc.pipe(duplex=True, encoder=dill.dumps, decoder=dill.loads) gipc.start_process(run_shard, (self.config, id, cpipe)) self.shards[id] = GIPCProxy(self, ppipe) diff --git a/disco/state.py b/disco/state.py index ae50698..c0e5338 100644 --- a/disco/state.py +++ b/disco/state.py @@ -243,6 +243,9 @@ class State(object): if event.member.guild_id not in self.guilds: return + if event.member.id not in self.guilds[event.member.guild_id].members: + return + self.guilds[event.member.guild_id].members[event.member.id].update(event.member) def on_guild_member_remove(self, event): From 417bf33d562f46d8867c16450837c3f9c2244d59 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 18 Oct 2016 20:49:23 -0500 Subject: [PATCH 05/91] Better autosharding interface, more fixes --- disco/bot/bot.py | 1 + disco/bot/command.py | 5 ++-- disco/bot/plugin.py | 6 +++- disco/gateway/sharder.py | 62 +++++++++++++++++++++++++++++++++++----- disco/types/guild.py | 1 + 5 files changed, 65 insertions(+), 10 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 70abe0b..2850f76 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -113,6 +113,7 @@ class Bot(object): def __init__(self, client, config=None): self.client = client self.config = config or BotConfig() + self.shards = {} # The context carries information about events in a threadlocal storage self.ctx = ThreadLocal() diff --git a/disco/bot/command.py b/disco/bot/command.py index 2546c9d..f2efbb1 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -6,7 +6,7 @@ from disco.bot.parser import ArgumentSet, ArgumentError from disco.util.functional import cached_property REGEX_FMT = '({})' -ARGS_REGEX = '( (.*)$|$)' +ARGS_REGEX = '( ((?:\n|.)*)$|$)' MENTION_RE = re.compile('<@!?([0-9]+)>') CommandLevels = Enum( @@ -109,7 +109,7 @@ class Command(object): self.triggers = [trigger] self.update(*args, **kwargs) - def update(self, args=None, level=None, aliases=None, group=None, is_regex=None): + def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False): self.triggers += aliases or [] def resolve_role(ctx, id): @@ -127,6 +127,7 @@ class Command(object): self.level = level self.group = group self.is_regex = is_regex + self.oob = oob @staticmethod def mention_type(getters, force=False): diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 8451b82..a6ffabe 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -206,6 +206,8 @@ class Plugin(LoggingClass, PluginDeco): """ Executes a CommandEvent this plugin owns """ + if not event.command.oob: + self.greenlets.add(gevent.getcurrent()) try: return event.command.execute(event) except CommandError as e: @@ -221,7 +223,9 @@ class Plugin(LoggingClass, PluginDeco): getattr(self, '_' + when)[typ].append(func) def _dispatch(self, typ, func, event, *args, **kwargs): - self.greenlets.add(gevent.getcurrent()) + # TODO: this is ugly + if typ != 'command': + self.greenlets.add(gevent.getcurrent()) self.ctx['plugin'] = self if hasattr(event, 'guild'): diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index cb387a4..b77d9b9 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -1,9 +1,11 @@ from __future__ import absolute_import +import six import gipc import gevent import logging import dill +import types from holster.log import set_logging_levels @@ -11,11 +13,34 @@ from disco.client import Client from disco.bot import Bot, BotConfig from disco.api.client import APIClient from disco.gateway.ipc.gipc import GIPCProxy +from disco.util.snowflake import calculate_shard + + +def dump_function(func): + if six.PY3: + return dill.dumps(( + func.__code__, + func.__name__, + func.__defaults__, + func.__closure__, + )) + else: + return dill.dumps(( + func.func_code, + func.func_name, + func.func_defaults, + func.func_closure + )) + + +def load_function(func): + code, name, defaults, closure = dill.loads(func) + return types.FunctionType(code, globals(), name, defaults, closure) def run_on(id, proxy): def f(func): - return proxy.call(('run_on', ), id, dill.dumps(func)) + return proxy.call(('run_on', ), id, dump_function(func)) return f @@ -38,14 +63,36 @@ def run_shard(config, id, pipe): client = Client(config) bot = Bot(client, BotConfig(config.bot)) bot.sharder = GIPCProxy(bot, pipe) - bot.shards = { - i: run_on(i, bot.sharder) for i in range(config.shard_count) - if i != id - } - bot.shards[id] = run_self(bot) + bot.shards = ShardHelper(config.shard_count, bot) bot.run_forever() +class ShardHelper(object): + def __init__(self, count, bot): + self.count = count + self.bot = bot + + def keys(self): + for id in xrange(self.count): + yield id + + def on(self, id, func): + if id == self.bot.client.config.shard_id: + result = gevent.event.AsyncResult() + result.set(func(self.bot)) + return result + + return self.bot.sharder.call(('run_on', ), id, dump_function(func)) + + def all(self, func, timeout=None): + pool = gevent.pool.Pool(self.count) + return dict(zip(range(self.count), pool.imap(lambda i: self.on(i, func).wait(timeout=timeout), range(self.count)))) + + def for_id(self, id, func): + shard = calculate_shard(self.count, id) + return self.on(shard, func) + + class AutoSharder(object): def __init__(self, config): self.config = config @@ -56,7 +103,8 @@ class AutoSharder(object): self.config.shard_count = 10 def run_on(self, id, raw): - func = dill.loads(raw) + func = load_function(raw) + # func = dill.loads(raw) return self.shards[id].execute(func).wait(timeout=15) def run(self): diff --git a/disco/types/guild.py b/disco/types/guild.py index 9708e3e..60b3c23 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -245,6 +245,7 @@ class Guild(SlottedModel, Permissible): roles = Field(dictof(Role, key='id')) emojis = Field(dictof(Emoji, key='id')) voice_states = Field(dictof(VoiceState, key='session_id')) + member_count = Field(int) synced = Field(bool, default=False) From 13ee463dbb4fdcfa1543cdfc6e363d68a4ebf58d Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 19 Oct 2016 17:06:55 -0500 Subject: [PATCH 06/91] Cleanup and fixes --- disco/api/client.py | 8 ++++++++ disco/api/http.py | 8 +++++++- disco/cli.py | 2 +- disco/types/channel.py | 20 ++++++++++++++++---- disco/types/guild.py | 3 ++- disco/util/hashmap.py | 4 ++-- 6 files changed, 36 insertions(+), 9 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index 6705f16..afd360b 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -25,6 +25,14 @@ class APIClient(LoggingClass): """ An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits the models with the returned data. + + Args + ---- + token : str + The Discord authentication token (without prefixes) to be used for all + HTTP requests. + client : :class:`disco.client.Client` + The base disco client which will be used when constructing models. """ def __init__(self, token, client=None): super(APIClient, self).__init__() diff --git a/disco/api/http.py b/disco/api/http.py index 4757da4..982668e 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -113,6 +113,12 @@ class APIException(Exception): """ Exception thrown when an HTTP-client level error occurs. Usually this will be a non-success status-code, or a transient network issue. + + Attributes + ---------- + status_code : int + The status code returned by the API for the request that triggered this + error. """ def __init__(self, msg, status_code=0, content=None): self.status_code = status_code @@ -200,7 +206,7 @@ class HTTPClient(LoggingClass): # If we got a success status code, just return the data if r.status_code < 400: return r - elif r.status_code != 429 and 400 < r.status_code < 500: + elif r.status_code != 429 and 400 <= r.status_code < 500: raise APIException('Request failed', r.status_code, r.content) else: if r.status_code == 429: diff --git a/disco/cli.py b/disco/cli.py index 487de95..01c4aaf 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -40,7 +40,6 @@ def disco_main(run=False): from disco.client import Client, ClientConfig from disco.bot import Bot, BotConfig - from disco.gateway.sharder import AutoSharder from disco.util.token import is_valid_token from holster.log import set_logging_levels @@ -58,6 +57,7 @@ def disco_main(run=False): return if args.shard_auto: + from disco.gateway.sharder import AutoSharder AutoSharder(config).run() return diff --git a/disco/types/channel.py b/disco/types/channel.py index f824a35..4c2a54c 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -280,12 +280,24 @@ class Channel(SlottedModel, Permissible): if not messages: return - if len(messages) <= 2: - for msg in messages: - self.delete_message(msg) - else: + if self.can(self.client.state.me, Permissions.MANAGE_MESSAGES) and len(messages) > 2: for chunk in chunks(messages, 100): self.client.api.channels_messages_delete_bulk(self.id, chunk) + else: + for msg in messages: + self.delete_message(msg) + + def delete(self): + assert (self.is_dm or self.guild.can(self.client.state.me, Permissions.MANAGE_GUILD)), 'Invalid Permissions' + self.client.api.channels_delete(self.id) + + def close(self): + """ + Closes a DM channel. This is intended as a safer version of `delete`, + enforcing that the channel is actually a DM. + """ + assert self.is_dm, 'Cannot close non-DM channel' + self.delete() class MessageIterator(object): diff --git a/disco/types/guild.py b/disco/types/guild.py index 60b3c23..f6e1212 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -7,7 +7,7 @@ from disco.api.http import APIException from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum -from disco.types.user import User +from disco.types.user import User, Presence from disco.types.voice import VoiceState from disco.types.channel import Channel from disco.types.permissions import PermissionValue, Permissions, Permissible @@ -246,6 +246,7 @@ class Guild(SlottedModel, Permissible): emojis = Field(dictof(Emoji, key='id')) voice_states = Field(dictof(VoiceState, key='session_id')) member_count = Field(int) + presences = Field(listof(Presence)) synced = Field(bool, default=False) diff --git a/disco/util/hashmap.py b/disco/util/hashmap.py index ef32647..50cf6f4 100644 --- a/disco/util/hashmap.py +++ b/disco/util/hashmap.py @@ -45,12 +45,12 @@ class HashMap(UserDict): def filter(self, predicate): if not callable(predicate): raise TypeError('predicate must be callable') - return filter(self.values(), predicate) + return filter(predicate, self.values()) def map(self, predicate): if not callable(predicate): raise TypeError('predicate must be callable') - return map(self.values(), predicate) + return map(predicate, self.values()) class DefaultHashMap(defaultdict, HashMap): From c7ff3b3d1808172fd0b840db3845a7e649de5728 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 19 Oct 2016 19:24:27 -0500 Subject: [PATCH 07/91] Move IPC code, bit more cleanup --- disco/bot/bot.py | 5 ++++- disco/gateway/{ipc/gipc.py => ipc.py} | 0 disco/gateway/ipc/__init__.py | 0 disco/gateway/packets.py | 4 ++-- disco/gateway/sharder.py | 2 +- 5 files changed, 7 insertions(+), 4 deletions(-) rename disco/gateway/{ipc/gipc.py => ipc.py} (100%) delete mode 100644 disco/gateway/ipc/__init__.py diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 2850f76..c4f391b 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -113,7 +113,9 @@ class Bot(object): def __init__(self, client, config=None): self.client = client self.config = config or BotConfig() - self.shards = {} + + # Shard manager + self.shards = None # The context carries information about events in a threadlocal storage self.ctx = ThreadLocal() @@ -123,6 +125,7 @@ class Bot(object): if self.config.storage_enabled: self.storage = Storage(self.ctx, self.config.from_prefix('storage')) + # If the manhole is enabled, add this bot as a local if self.client.config.manhole_enable: self.client.manhole_locals['bot'] = self diff --git a/disco/gateway/ipc/gipc.py b/disco/gateway/ipc.py similarity index 100% rename from disco/gateway/ipc/gipc.py rename to disco/gateway/ipc.py diff --git a/disco/gateway/ipc/__init__.py b/disco/gateway/ipc/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/disco/gateway/packets.py b/disco/gateway/packets.py index e78c1ce..a15bfd8 100644 --- a/disco/gateway/packets.py +++ b/disco/gateway/packets.py @@ -1,7 +1,7 @@ from holster.enum import Enum -SEND = object() -RECV = object() +SEND = 1 +RECV = 2 OPCode = Enum( DISPATCH=0, diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index b77d9b9..f718c3d 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -12,7 +12,7 @@ from holster.log import set_logging_levels from disco.client import Client from disco.bot import Bot, BotConfig from disco.api.client import APIClient -from disco.gateway.ipc.gipc import GIPCProxy +from disco.gateway.ipc import GIPCProxy from disco.util.snowflake import calculate_shard From fb2a19efb8752c622b937f97a0db11129dfce6e1 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 19 Oct 2016 19:43:36 -0500 Subject: [PATCH 08/91] Remove dill requirement, switch to marshal for serialization, etc --- README.md | 1 + disco/gateway/ipc.py | 6 +++--- disco/gateway/sharder.py | 44 +++------------------------------------- disco/util/serializer.py | 35 ++++++++++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 027229d..6c52569 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ Disco was built to run both as a generic-use library, and a standalone bot toolk |requests[security]|adds packages for a proper SSL implementation| |ujson|faster json parser, improves performance| |erlpack|ETF parser, only Python 2.x, run with the --encoder=etf flag| +|gipc|Gevent IPC, required for autosharding| ## Examples diff --git a/disco/gateway/ipc.py b/disco/gateway/ipc.py index e23206b..bcd3383 100644 --- a/disco/gateway/ipc.py +++ b/disco/gateway/ipc.py @@ -2,11 +2,11 @@ import random import gevent import string import weakref -import dill from holster.enum import Enum from disco.util.logging import LoggingClass +from disco.util.serializer import dump_function, load_function def get_random_str(size): @@ -49,7 +49,7 @@ class GIPCProxy(LoggingClass): self.send(IPCMessageType.RESPONSE, (nonce, self.resolve(path))) elif mtype == IPCMessageType.EXECUTE: nonce, raw = data - func = dill.loads(raw) + func = load_function(raw) try: result = func(self.obj) except Exception: @@ -73,7 +73,7 @@ class GIPCProxy(LoggingClass): def execute(self, func): nonce = get_random_str(32) - raw = dill.dumps(func) + raw = dump_function(func) self.results[nonce] = result = gevent.event.AsyncResult() self.pipe.put((IPCMessageType.EXECUTE.value, (nonce, raw))) return result diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index f718c3d..5d98ad6 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -1,11 +1,9 @@ from __future__ import absolute_import -import six import gipc import gevent import logging -import dill -import types +import marshal from holster.log import set_logging_levels @@ -14,42 +12,7 @@ from disco.bot import Bot, BotConfig from disco.api.client import APIClient from disco.gateway.ipc import GIPCProxy from disco.util.snowflake import calculate_shard - - -def dump_function(func): - if six.PY3: - return dill.dumps(( - func.__code__, - func.__name__, - func.__defaults__, - func.__closure__, - )) - else: - return dill.dumps(( - func.func_code, - func.func_name, - func.func_defaults, - func.func_closure - )) - - -def load_function(func): - code, name, defaults, closure = dill.loads(func) - return types.FunctionType(code, globals(), name, defaults, closure) - - -def run_on(id, proxy): - def f(func): - return proxy.call(('run_on', ), id, dump_function(func)) - return f - - -def run_self(bot): - def f(func): - result = gevent.event.AsyncResult() - result.set(func(bot)) - return result - return f +from disco.util.serializer import dump_function, load_function def run_shard(config, id, pipe): @@ -104,7 +67,6 @@ class AutoSharder(object): def run_on(self, id, raw): func = load_function(raw) - # func = dill.loads(raw) return self.shards[id].execute(func).wait(timeout=15) def run(self): @@ -121,6 +83,6 @@ class AutoSharder(object): ) def start_shard(self, id): - cpipe, ppipe = gipc.pipe(duplex=True, encoder=dill.dumps, decoder=dill.loads) + cpipe, ppipe = gipc.pipe(duplex=True, encoder=marshal.dumps, decoder=marshal.loads) gipc.start_process(run_shard, (self.config, id, cpipe)) self.shards[id] = GIPCProxy(self, ppipe) diff --git a/disco/util/serializer.py b/disco/util/serializer.py index 74fe766..e248490 100644 --- a/disco/util/serializer.py +++ b/disco/util/serializer.py @@ -1,3 +1,5 @@ +import six +import types class Serializer(object): @@ -36,3 +38,36 @@ class Serializer(object): def dumps(cls, fmt, raw): _, dumps = getattr(cls, fmt)() return dumps(raw) + + +def dump_cell(cell): + return cell.cell_contents + + +def load_cell(cell): + if six.PY3: + return (lambda y: cell).__closure__[0] + else: + return (lambda y: cell).func_closure[0] + + +def dump_function(func): + if six.PY3: + return ( + func.__code__, + func.__name__, + func.__defaults__, + list(map(dump_cell, func.__closure__)) if func.__closure__ else [], + ) + else: + return ( + func.func_code, + func.func_name, + func.func_defaults, + list(map(dump_cell, func.func_closure)) if func.func_closure else [], + ) + + +def load_function((code, name, defaults, closure)): + closure = tuple(map(load_cell, closure)) + return types.FunctionType(code, globals(), name, defaults, closure) From 0e3f518ad1f58329ed20fc983a5fa8d5691bc90a Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 19 Oct 2016 20:15:00 -0500 Subject: [PATCH 09/91] Add Client.update_presence, fix heartbeating on dead connections --- disco/client.py | 22 ++++++++++++++++++++++ disco/gateway/client.py | 6 ++++++ disco/types/base.py | 15 +++++++++++++-- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/disco/client.py b/disco/client.py index 0c15c50..d447ecc 100644 --- a/disco/client.py +++ b/disco/client.py @@ -1,3 +1,4 @@ +import time import gevent from holster.emitter import Emitter @@ -5,6 +6,8 @@ from holster.emitter import Emitter from disco.state import State, StateConfig from disco.api.client import APIClient from disco.gateway.client import GatewayClient +from disco.gateway.packets import OPCode +from disco.types.user import Status, Game from disco.util.config import Config from disco.util.logging import LoggingClass from disco.util.backdoor import DiscoBackdoorServer @@ -99,6 +102,25 @@ class Client(object): localf=lambda: self.manhole_locals) self.manhole.start() + def update_presence(self, game=None, status=None, afk=False, since=0.0): + if game and not isinstance(game, Game): + raise TypeError('Game must be a Game model') + + if status is Status.IDLE and not since: + since = int(time.time() * 1000) + + payload = { + 'afk': afk, + 'since': since, + 'status': status.value, + 'game': None, + } + + if game: + payload['game'] = game.to_dict() + + self.gw.send(OPCode.STATUS_UPDATE, payload) + def run(self): """ Run the client (e.g. the :class:`GatewayClient`) in a new greenlet diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 61f7dfd..251b800 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -173,10 +173,16 @@ class GatewayClient(LoggingClass): }) def on_close(self, code, reason): + # Kill heartbeater, a reconnect/resume will trigger a HELLO which will + # respawn it + self._heartbeat_task.kill() + + # If we're quitting, just break out of here if self.shutting_down: self.log.info('WS Closed: shutting down') return + # Track reconnect attempts self.reconnects += 1 self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) diff --git a/disco/types/base.py b/disco/types/base.py index 63bc628..78536a3 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -3,7 +3,7 @@ import gevent import inspect import functools -from holster.enum import BaseEnumMeta +from holster.enum import BaseEnumMeta, EnumAttr from datetime import datetime as real_datetime from disco.util.functional import CachedSlotProperty @@ -33,6 +33,14 @@ class FieldType(object): else: self.typ = lambda raw, _: typ(raw) + def serialize(self, value): + if isinstance(value, EnumAttr): + return value.value + elif isinstance(value, Model): + return value.to_dict() + else: + return value + def try_convert(self, raw, client): pass @@ -260,7 +268,10 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): pass def to_dict(self): - return {k: getattr(self, k) for k in six.iterkeys(self.__class__._fields)} + obj = {} + for name, field in six.iteritems(self.__class__._fields): + obj[name] = field.serialize(getattr(self, name)) + return obj @classmethod def create(cls, client, data, **kwargs): From e24243076b2394fdced401b7bd5521defea1ba4b Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 20 Oct 2016 02:23:47 -0500 Subject: [PATCH 10/91] Document gateway events, couple fixes --- disco/client.py | 2 +- disco/gateway/client.py | 4 +- disco/gateway/events.py | 228 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 215 insertions(+), 19 deletions(-) diff --git a/disco/client.py b/disco/client.py index d447ecc..66bcfb3 100644 --- a/disco/client.py +++ b/disco/client.py @@ -112,7 +112,7 @@ class Client(object): payload = { 'afk': afk, 'since': since, - 'status': status.value, + 'status': status.value.lower(), 'game': None, } diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 251b800..98d71d6 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -161,8 +161,8 @@ class GatewayClient(LoggingClass): 'compress': True, 'large_threshold': 250, 'shard': [ - self.client.config.shard_id, - self.client.config.shard_count, + int(self.client.config.shard_id), + int(self.client.config.shard_count), ], 'properties': { '$os': 'linux', diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 5b2e8b5..2129e53 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -7,7 +7,7 @@ from disco.types.user import User, Presence from disco.types.channel import Channel from disco.types.message import Message from disco.types.voice import VoiceState -from disco.types.guild import Guild, GuildMember, Role +from disco.types.guild import Guild, GuildMember, Role, Emoji from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime @@ -101,6 +101,19 @@ class Ready(GatewayEvent): """ Sent after the initial gateway handshake is complete. Contains data required for bootstrapping the client's states. + + Attributes + ----- + version : int + The gateway version. + session_id : str + The session ID. + user : :class:`disco.types.user.User` + The user object for the authed account. + guilds : list[:class:`disco.types.guild.Guild` + All guilds this account is a member of. These are shallow guild objects. + private_channels list[:class:`disco.types.channel.Channel`] + All private channels (DMs) open for this account. """ version = Field(int, alias='v') session_id = Field(str) @@ -119,31 +132,70 @@ class Resumed(GatewayEvent): @wraps_model(Guild) class GuildCreate(GatewayEvent): """ - Sent when a guild is created, or becomes available. + Sent when a guild is joined, or becomes available. + + Attributes + ----- + guild : :class:`disco.types.guild.Guild` + The guild being created (e.g. joined) + unavailable : bool + If false, this guild is coming online from a previously unavailable state, + and if None, this is a normal guild join event. """ unavailable = Field(bool) + @property + def created(self): + """ + Shortcut property which is true when we actually joined the guild. + """ + return self.unavailable is None + @wraps_model(Guild) class GuildUpdate(GatewayEvent): """ Sent when a guild is updated. + + Attributes + ----- + guild : :class:`disco.types.guild.Guild` + The updated guild object. """ - pass class GuildDelete(GatewayEvent): """ - Sent when a guild is deleted, or becomes unavailable. + Sent when a guild is deleted, left, or becomes unavailable. + + Attributes + ----- + id : snowflake + The ID of the guild being deleted. + unavailable : bool + If true, this guild is becoming unavailable, if None this is a normal + guild leave event. """ id = Field(snowflake) unavailable = Field(bool) + @property + def deleted(self): + """ + Shortcut property which is true when we actually have left the guild. + """ + return self.unavailable is None + @wraps_model(Channel) class ChannelCreate(GatewayEvent): """ Sent when a channel is created. + + Attributes + ----- + channel : :class:`disco.types.channel.Channel` + The channel which was created. """ @@ -151,21 +203,36 @@ class ChannelCreate(GatewayEvent): class ChannelUpdate(ChannelCreate): """ Sent when a channel is updated. + + Attributes + ----- + channel : :class:`disco.types.channel.Channel` + The channel which was updated. """ - pass @wraps_model(Channel) class ChannelDelete(ChannelCreate): """ Sent when a channel is deleted. + + Attributes + ----- + channel : :class:`disco.types.channel.Channel` + The channel being deleted. """ - pass class ChannelPinsUpdate(GatewayEvent): """ Sent when a channel's pins are updated. + + Attributes + ----- + channel_id : snowflake + ID of the channel where pins where updated. + last_pin_timestap : datetime + The time the last message was pinned. """ channel_id = Field(snowflake) last_pin_timestamp = Field(lazy_datetime) @@ -175,35 +242,68 @@ class ChannelPinsUpdate(GatewayEvent): class GuildBanAdd(GatewayEvent): """ Sent when a user is banned from a guild. + + Attributes + ----- + guild_id : snowflake + The ID of the guild the user is being banned from. + user : :class:`disco.types.user.User` + The user being banned from the guild. """ - pass + guild_id = Field(snowflake) @wraps_model(User) class GuildBanRemove(GuildBanAdd): """ Sent when a user is unbanned from a guild. + + Attributes + ----- + guild_id : snowflake + The ID of the guild the user is being unbanned from. + user : :class:`disco.types.user.User` + The user being unbanned from the guild. """ - pass class GuildEmojisUpdate(GatewayEvent): """ Sent when a guild's emojis are updated. + + Attributes + ----- + guild_id : snowflake + The ID of the guild the emojis are being updated in. + emojis : list[:class:`disco.types.guild.Emoji`] + The new set of emojis for the guild """ - pass + guild_id = Field(snowflake) + emojis = Field(listof(Emoji)) class GuildIntegrationsUpdate(GatewayEvent): """ Sent when a guild's integrations are updated. + + Attributes + ----- + guild_id : snowflake + The ID of the guild integrations where updated in. """ - pass + guild_id = Field(snowflake) class GuildMembersChunk(GatewayEvent): """ Sent in response to a member's chunk request. + + Attributes + ----- + guild_id : snowflake + The ID of the guild this member chunk is for. + members : list[:class:`disco.types.guild.GuildMember`] + The chunk of members. """ guild_id = Field(snowflake) members = Field(listof(GuildMember)) @@ -213,13 +313,24 @@ class GuildMembersChunk(GatewayEvent): class GuildMemberAdd(GatewayEvent): """ Sent when a user joins a guild. + + Attributes + ----- + member : :class:`disco.types.guild.GuildMember` + The member that has joined the guild. """ - pass class GuildMemberRemove(GatewayEvent): """ Sent when a user leaves a guild (via leaving, kicking, or banning). + + Attributes + ----- + guild_id : snowflake + The ID of the guild the member left from. + user : :class:`disco.types.user.User` + The user who was removed from the guild. """ guild_id = Field(snowflake) user = Field(User) @@ -229,13 +340,24 @@ class GuildMemberRemove(GatewayEvent): class GuildMemberUpdate(GatewayEvent): """ Sent when a guilds member is updated. + + Attributes + ----- + member : :class:`disco.types.guild.GuildMember` + The member being updated """ - pass class GuildRoleCreate(GatewayEvent): """ Sent when a role is created. + + Attributes + ----- + guild_id : snowflake + The ID of the guild where the role was created. + role : :class:`disco.types.guild.Role` + The role that was created. """ guild_id = Field(snowflake) role = Field(Role) @@ -244,13 +366,26 @@ class GuildRoleCreate(GatewayEvent): class GuildRoleUpdate(GuildRoleCreate): """ Sent when a role is updated. + + Attributes + ----- + guild_id : snowflake + The ID of the guild where the role was created. + role : :class:`disco.types.guild.Role` + The role that was created. """ - pass class GuildRoleDelete(GatewayEvent): """ Sent when a role is deleted. + + Attributes + ----- + guild_id : snowflake + The ID of the guild where the role is being deleted. + role_id : snowflake + The id of the role being deleted. """ guild_id = Field(snowflake) role_id = Field(snowflake) @@ -260,6 +395,11 @@ class GuildRoleDelete(GatewayEvent): class MessageCreate(GatewayEvent): """ Sent when a message is created. + + Attributes + ----- + message : :class:`disco.types.message.Message` + The message being created. """ @@ -267,13 +407,24 @@ class MessageCreate(GatewayEvent): class MessageUpdate(MessageCreate): """ Sent when a message is updated/edited. + + Attributes + ----- + message : :class:`disco.types.message.Message` + The message being updated. """ - pass class MessageDelete(GatewayEvent): """ Sent when a message is deleted. + + Attributes + ----- + id : snowflake + The ID of message being deleted. + channel_id : snowflake + The ID of the channel the message was deleted in. """ id = Field(snowflake) channel_id = Field(snowflake) @@ -282,6 +433,13 @@ class MessageDelete(GatewayEvent): class MessageDeleteBulk(GatewayEvent): """ Sent when multiple messages are deleted from a channel. + + Attributes + ----- + channel_id : snowflake + The channel the messages are being deleted in. + ids : list[snowflake] + List of messages being deleted in the channel. """ channel_id = Field(snowflake) ids = Field(listof(snowflake)) @@ -291,6 +449,15 @@ class MessageDeleteBulk(GatewayEvent): class PresenceUpdate(GatewayEvent): """ Sent when a user's presence is updated. + + Attributes + ----- + presence : :class:`disco.types.user.Presence` + The updated presence object. + guild_id : snowflake + The guild this presence update is for. + roles : list[snowflake] + List of roles the user from the presence is part of. """ guild_id = Field(snowflake) roles = Field(listof(snowflake)) @@ -299,23 +466,45 @@ class PresenceUpdate(GatewayEvent): class TypingStart(GatewayEvent): """ Sent when a user begins typing in a channel. + + Attributes + ----- + channel_id : snowflake + The ID of the channel where the user is typing. + user_id : snowflake + The ID of the user who is typing. + timestamp : datetime + When the user started typing. """ channel_id = Field(snowflake) user_id = Field(snowflake) - timestamp = Field(snowflake) + timestamp = Field(lazy_datetime) @wraps_model(VoiceState, alias='state') class VoiceStateUpdate(GatewayEvent): """ Sent when a users voice state changes. + + Attributes + ----- + state : :class:`disco.models.voice.VoiceState` + The voice state which was updated. """ - pass class VoiceServerUpdate(GatewayEvent): """ Sent when a voice server is updated. + + Attributes + ----- + token : str + The token for the voice server. + endpoint : str + The endpoint for the voice server. + guild_id : snowflake + The guild ID this voice server update is for. """ token = Field(str) endpoint = Field(str) @@ -325,6 +514,13 @@ class VoiceServerUpdate(GatewayEvent): class WebhooksUpdate(GatewayEvent): """ Sent when a channels webhooks are updated. + + Attributes + ----- + channel_id : snowflake + The channel ID this webhooks update is for. + guild_id : snowflake + The guild ID this webhooks update is for. """ channel_id = Field(snowflake) guild_id = Field(snowflake) From cc8fd6db84d5723c5fc0f91fc997c0ec2d9f2178 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 20 Oct 2016 03:16:51 -0500 Subject: [PATCH 11/91] Better listener support, another fix to VoiceStateUpdate --- disco/bot/plugin.py | 22 ++++++++++------------ disco/state.py | 8 ++++++++ requirements.txt | 2 +- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index a6ffabe..40b6e49 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -41,27 +41,27 @@ class PluginDeco(object): return deco @classmethod - def listen(cls, event_name, priority=None): + def listen(cls, *args, **kwargs): """ Binds the function to listen for a given event name """ return cls.add_meta_deco({ 'type': 'listener', 'what': 'event', - 'desc': event_name, - 'priority': priority + 'args': args, + 'kwargs': kwargs, }) @classmethod - def listen_packet(cls, op, priority=None): + def listen_packet(cls, *args, **kwargs): """ Binds the function to listen for a given gateway op code """ return cls.add_meta_deco({ 'type': 'listener', 'what': 'packet', - 'desc': op, - 'priority': priority, + 'args': args, + 'kwargs': kwargs, }) @classmethod @@ -187,7 +187,7 @@ class Plugin(LoggingClass, PluginDeco): def bind_meta(self, member, meta): if meta['type'] == 'listener': - self.register_listener(member, meta['what'], meta['desc'], meta['priority']) + self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs']) elif meta['type'] == 'command': meta['kwargs']['update'] = True self.register_command(member, *meta['args'], **meta['kwargs']) @@ -248,7 +248,7 @@ class Plugin(LoggingClass, PluginDeco): return True - def register_listener(self, func, what, desc, priority): + def register_listener(self, func, what, desc, priority=Priority.NONE, conditional=None): """ Registers a listener @@ -265,12 +265,10 @@ class Plugin(LoggingClass, PluginDeco): """ func = functools.partial(self._dispatch, 'listener', func) - priority = priority or Priority.NONE - if what == 'event': - li = self.bot.client.events.on(desc, func, priority=priority) + li = self.bot.client.events.on(desc, func, priority=priority, conditional=conditional) elif what == 'packet': - li = self.bot.client.packets.on(desc, func, priority=priority) + li = self.bot.client.packets.on(desc, func, priority=priority, conditional=conditional) else: raise Exception('Invalid listener what: {}'.format(what)) diff --git a/disco/state.py b/disco/state.py index c0e5338..abe4cff 100644 --- a/disco/state.py +++ b/disco/state.py @@ -185,6 +185,9 @@ class State(object): for member in six.itervalues(event.guild.members): self.users[member.user.id] = member.user + for voice_state in six.itervalues(event.guild.voice_states): + self.voice_states[voice_state.session_id] = voice_state + if self.config.sync_guild_members: event.guild.sync() @@ -225,8 +228,13 @@ class State(object): guild.voice_states[event.state.session_id].update(event.state) else: del guild.voice_states[event.state.session_id] + + # Prevent a weird race where events come in before the guild_create (I think...) + if event.state.session_id in self.voice_states: + del self.voice_states[event.state.session_id] elif event.state.channel_id: guild.voice_states[event.state.session_id] = event.state + self.voice_states[event.state.session_id] = event.state def on_guild_member_add(self, event): if event.member.user.id not in self.users: diff --git a/requirements.txt b/requirements.txt index 9051886..9a1e8fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ gevent==1.1.2 -holster==1.0.6 +holster==1.0.8 inflection==0.3.1 requests==2.11.1 six==1.10.0 From 18fe47dbf50936c9193715c3b954f19a9ff250d3 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 20 Oct 2016 03:18:31 -0500 Subject: [PATCH 12/91] Fuck python3 --- disco/util/serializer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/disco/util/serializer.py b/disco/util/serializer.py index e248490..a481a9d 100644 --- a/disco/util/serializer.py +++ b/disco/util/serializer.py @@ -68,6 +68,7 @@ def dump_function(func): ) -def load_function((code, name, defaults, closure)): +def load_function(args): + code, name, defaults, closure = args closure = tuple(map(load_cell, closure)) return types.FunctionType(code, globals(), name, defaults, closure) From 67b6967ce987333aa77861cf274bd5cde3184a08 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 20 Oct 2016 03:23:22 -0500 Subject: [PATCH 13/91] Release 0.0.6 --- disco/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/__init__.py b/disco/__init__.py index 6e83b38..0cdff65 100644 --- a/disco/__init__.py +++ b/disco/__init__.py @@ -1 +1 @@ -VERSION = '0.0.5' +VERSION = '0.0.6' From fabab89e368d0d1cea08236813a97b8b6fc0e68b Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 27 Oct 2016 01:42:50 -0500 Subject: [PATCH 14/91] Message reactions support (untested) --- disco/api/client.py | 17 +++++++++++++ disco/api/http.py | 4 ++++ disco/bot/providers/redis.py | 16 +++++++------ disco/gateway/encoding/base.py | 4 +++- disco/gateway/events.py | 44 +++++++++++++++++++++++++++++++++- disco/types/message.py | 12 ++++++++++ 6 files changed, 88 insertions(+), 9 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index afd360b..869d49c 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -98,6 +98,23 @@ class APIClient(LoggingClass): def channels_messages_delete_bulk(self, channel, messages): self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages}) + def channels_messages_reactions_get(self, channel, message, emoji): + r = self.http(Routes.CHANNELS_MESSAGES_REACTIONS_GET, dict(channel=channel, message=message, emoji=emoji)) + return User.create_map(self.client, r.json()) + + def channels_messages_reactions_create(self, channel, message, emoji): + self.http(Routes.CHANNELS_MESSAGES_REACTIONS_CREATE, dict(channel=channel, message=message, emoji=emoji)) + + def channels_messages_reactions_delete(self, channel, message, emoji, user=None): + route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_ME + obj = dict(channel=channel, message=message, emoji=emoji) + + if user: + route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_USER + obj['user'] = user + + self.http(route, obj) + def channels_permissions_modify(self, channel, permission, allow, deny, typ): self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={ 'allow': allow, diff --git a/disco/api/http.py b/disco/api/http.py index 982668e..9f90649 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -39,6 +39,10 @@ class Routes(object): CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete') + CHANNELS_MESSAGES_REACTIONS_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}/reactions/{emoji}') + CHANNELS_MESSAGES_REACTIONS_CREATE = (HTTPMethod.PUT, CHANNELS + '/messages/{message}/reactions/{emoji}/@me') + CHANNELS_MESSAGES_REACTIONS_DELETE_ME = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/@me') + CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/{user}') CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}') CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}') CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites') diff --git a/disco/bot/providers/redis.py b/disco/bot/providers/redis.py index d0c2e5b..239ac9c 100644 --- a/disco/bot/providers/redis.py +++ b/disco/bot/providers/redis.py @@ -11,31 +11,33 @@ from .base import BaseProvider, SEP_SENTINEL class RedisProvider(BaseProvider): def __init__(self, config): self.config = config + self.format = config.get('format', 'pickle') def load(self): - self.redis = redis.Redis( + self.conn = redis.Redis( host=self.config.get('host', 'localhost'), port=self.config.get('port', 6379), db=self.config.get('db', 0)) def exists(self, key): - return self.db.exists(key) + return self.conn.exists(key) def keys(self, other): count = other.count(SEP_SENTINEL) + 1 - for key in self.db.scan_iter(u'{}*'.format(other)): + for key in self.conn.scan_iter(u'{}*'.format(other)): + key = key.decode('utf-8') if key.count(SEP_SENTINEL) == count: yield key def get_many(self, keys): - for key, value in izip(keys, self.db.mget(keys)): + for key, value in izip(keys, self.conn.mget(keys)): yield (key, Serializer.loads(self.format, value)) def get(self, key): - return Serializer.loads(self.format, self.db.get(key)) + return Serializer.loads(self.format, self.conn.get(key)) def set(self, key, value): - self.db.set(key, Serializer.dumps(self.format, value)) + self.conn.set(key, Serializer.dumps(self.format, value)) def delete(self, key, value): - self.db.delete(key) + self.conn.delete(key) diff --git a/disco/gateway/encoding/base.py b/disco/gateway/encoding/base.py index e663cf6..f4903d9 100644 --- a/disco/gateway/encoding/base.py +++ b/disco/gateway/encoding/base.py @@ -1,7 +1,9 @@ from websocket import ABNF +from holster.interface import Interface -class BaseEncoder(object): + +class BaseEncoder(Interface): TYPE = None OPCODE = ABNF.OPCODE_TEXT diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 2129e53..d1e3dd5 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -5,7 +5,7 @@ import six from disco.types.user import User, Presence from disco.types.channel import Channel -from disco.types.message import Message +from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState from disco.types.guild import Guild, GuildMember, Role, Emoji @@ -524,3 +524,45 @@ class WebhooksUpdate(GatewayEvent): """ channel_id = Field(snowflake) guild_id = Field(snowflake) + + +class MessageReactionAdd(GatewayEvent): + """ + Sent when a reaction is added to a message. + + Attributes + ---------- + channel_id : snowflake + The channel ID the message is in. + messsage_id : snowflake + The ID of the message for which the reaction was added too. + user_id : snowflake + The ID of the user who added the reaction. + emoji : :class:`disco.types.message.MessageReactionEmoji` + The emoji which was added. + """ + channel_id = Field(snowflake) + message_id = Field(snowflake) + user_id = Field(snowflake) + emoji = Field(MessageReactionEmoji) + + +class MessageReactionRemove(GatewayEvent): + """ + Sent when a reaction is removed from a message. + + Attributes + ---------- + channel_id : snowflake + The channel ID the message is in. + messsage_id : snowflake + The ID of the message for which the reaction was removed from. + user_id : snowflake + The ID of the user who originally added the reaction. + emoji : :class:`disco.types.message.MessageReactionEmoji` + The emoji which was removed. + """ + channel_id = Field(snowflake) + message_id = Field(snowflake) + user_id = Field(snowflake) + emoji = Field(MessageReactionEmoji) diff --git a/disco/types/message.py b/disco/types/message.py index 15f751d..509ad2b 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -19,6 +19,17 @@ MessageType = Enum( ) +class MessageReactionEmoji(SlottedModel): + id = Field(snowflake) + name = Field(text) + + +class MessageReaction(SlottedModel): + emoji = Field(MessageReactionEmoji) + count = Field(int) + me = Field(bool) + + class MessageEmbedFooter(SlottedModel): text = Field(text) icon_url = Field(text) @@ -170,6 +181,7 @@ class Message(SlottedModel): mention_roles = Field(listof(snowflake)) embeds = Field(listof(MessageEmbed)) attachments = Field(dictof(MessageAttachment, key='id')) + reactions = Field(listof(MessageReaction)) def __str__(self): return ''.format(self.id, self.channel_id) From ef4e87f7fbfac4a8b8729af76e5cff5c21a21777 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 28 Oct 2016 01:23:19 -0500 Subject: [PATCH 15/91] Better interface to reactions, etc cleanup --- disco/api/http.py | 8 ++++++++ disco/api/ratelimit.py | 13 ++++++++++--- disco/cli.py | 6 +++--- disco/client.py | 4 ++-- disco/gateway/sharder.py | 6 ++---- disco/types/channel.py | 3 +++ disco/types/guild.py | 15 +++++++++------ disco/types/message.py | 37 ++++++++++++++++++++++++++++++++++++- disco/util/logging.py | 31 ++++++++++++++++++++----------- 9 files changed, 93 insertions(+), 30 deletions(-) diff --git a/disco/api/http.py b/disco/api/http.py index 9f90649..8b6fdbb 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -18,6 +18,12 @@ HTTPMethod = Enum( ) +def to_bytes(obj): + if isinstance(obj, six.text_type): + return obj.encode('utf-8') + return obj + + class Routes(object): """ Simple Python object-enum of all method/url route combinations available to @@ -194,6 +200,7 @@ class HTTPClient(LoggingClass): kwargs['headers'] = self.headers # Build the bucket URL + args = {to_bytes(k): to_bytes(v) for k, v in six.iteritems(args)} filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in six.iteritems(args)} bucket = (route[0].value, route[1].format(**filtered)) @@ -202,6 +209,7 @@ class HTTPClient(LoggingClass): # Make the actual request url = self.BASE_URL + route[1].format(**args) + self.log.info('%s %s', route[0].value, url) r = requests.request(route[0].value, url, **kwargs) # Update rate limiter diff --git a/disco/api/ratelimit.py b/disco/api/ratelimit.py index 420d6f3..244b291 100644 --- a/disco/api/ratelimit.py +++ b/disco/api/ratelimit.py @@ -1,8 +1,10 @@ import time import gevent +from disco.util.logging import LoggingClass -class RouteState(object): + +class RouteState(LoggingClass): """ An object which stores ratelimit state for a given method/url route combination (as specified in :class:`disco.api.http.Routes`). @@ -36,6 +38,9 @@ class RouteState(object): self.update(response) + def __repr__(self): + return ''.format(' '.join(self.route)) + @property def chilled(self): """ @@ -92,12 +97,14 @@ class RouteState(object): raise Exception('Cannot cooldown for negative time period; check clock sync') self.event = gevent.event.Event() - gevent.sleep((self.reset_time - time.time()) + .5) + delay = (self.reset_time - time.time()) + .5 + self.log.debug('Cooling down bucket %s for %s seconds', self, delay) + gevent.sleep(delay) self.event.set() self.event = None -class RateLimiter(object): +class RateLimiter(LoggingClass): """ A in-memory store of ratelimit states for all routes we've ever called. diff --git a/disco/cli.py b/disco/cli.py index 01c4aaf..4b9d95b 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -41,7 +41,7 @@ def disco_main(run=False): from disco.client import Client, ClientConfig from disco.bot import Bot, BotConfig from disco.util.token import is_valid_token - from holster.log import set_logging_levels + from disco.util.logging import setup_logging if os.path.exists(args.config): config = ClientConfig.from_file(args.config) @@ -61,8 +61,8 @@ def disco_main(run=False): AutoSharder(config).run() return - logging.basicConfig(level=logging.INFO) - set_logging_levels() + # TODO: make configurable + setup_logging(level=logging.INFO) client = Client(config) diff --git a/disco/client.py b/disco/client.py index 66bcfb3..01b6635 100644 --- a/disco/client.py +++ b/disco/client.py @@ -13,7 +13,7 @@ from disco.util.logging import LoggingClass from disco.util.backdoor import DiscoBackdoorServer -class ClientConfig(LoggingClass, Config): +class ClientConfig(Config): """ Configuration for the :class:`Client`. @@ -46,7 +46,7 @@ class ClientConfig(LoggingClass, Config): encoder = 'json' -class Client(object): +class Client(LoggingClass): """ Class representing the base entry point that should be used in almost all implementation cases. This class wraps the functionality of both the REST API diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index 5d98ad6..c401a3e 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -5,22 +5,20 @@ import gevent import logging import marshal -from holster.log import set_logging_levels - from disco.client import Client from disco.bot import Bot, BotConfig from disco.api.client import APIClient from disco.gateway.ipc import GIPCProxy +from disco.util.logging import setup_logging from disco.util.snowflake import calculate_shard from disco.util.serializer import dump_function, load_function def run_shard(config, id, pipe): - logging.basicConfig( + setup_logging( level=logging.INFO, format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) ) - set_logging_levels() config.shard_id = id client = Client(config) diff --git a/disco/types/channel.py b/disco/types/channel.py index 4c2a54c..76e1420 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -186,6 +186,9 @@ class Channel(SlottedModel, Permissible): """ return MessageIterator(self.client, self, **kwargs) + def get_message(self, message): + return self.client.api.channels_messages_get(self.id, to_snowflake(message)) + def get_invites(self): """ Returns diff --git a/disco/types/guild.py b/disco/types/guild.py index f6e1212..14b781a 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -10,6 +10,7 @@ from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, tex from disco.types.user import User, Presence from disco.types.voice import VoiceState from disco.types.channel import Channel +from disco.types.message import Emoji from disco.types.permissions import PermissionValue, Permissions, Permissible @@ -22,7 +23,9 @@ VerificationLevel = Enum( ) -class GuildSubType(SlottedModel): +class GuildSubType(object): + __slots__ = [] + guild_id = Field(None) @cached_property @@ -30,7 +33,7 @@ class GuildSubType(SlottedModel): return self.client.state.guilds.get(self.guild_id) -class Emoji(GuildSubType): +class GuildEmoji(Emoji, GuildSubType): """ An emoji object @@ -54,7 +57,7 @@ class Emoji(GuildSubType): roles = Field(listof(snowflake)) -class Role(GuildSubType): +class Role(SlottedModel, GuildSubType): """ A role object @@ -95,7 +98,7 @@ class Role(GuildSubType): return '<@{}>'.format(self.id) -class GuildMember(GuildSubType): +class GuildMember(SlottedModel, GuildSubType): """ A GuildMember object @@ -222,7 +225,7 @@ class Guild(SlottedModel, Permissible): All of the guild's channels. roles : dict(snowflake, :class:`Role`) All of the guild's roles. - emojis : dict(snowflake, :class:`Emoji`) + emojis : dict(snowflake, :class:`GuildEmoji`) All of the guild's emojis. voice_states : dict(str, :class:`disco.types.voice.VoiceState`) All of the guild's voice states. @@ -243,7 +246,7 @@ class Guild(SlottedModel, Permissible): members = Field(dictof(GuildMember, key='id')) channels = Field(dictof(Channel, key='id')) roles = Field(dictof(Role, key='id')) - emojis = Field(dictof(Emoji, key='id')) + emojis = Field(dictof(GuildEmoji, key='id')) voice_states = Field(dictof(VoiceState, key='session_id')) member_count = Field(int) presences = Field(listof(Presence)) diff --git a/disco/types/message.py b/disco/types/message.py index 509ad2b..1df67c5 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -19,10 +19,24 @@ MessageType = Enum( ) -class MessageReactionEmoji(SlottedModel): +class Emoji(SlottedModel): id = Field(snowflake) name = Field(text) + def __eq__(self, other): + if isinstance(other, Emoji): + return self.id == other.id and self.name == other.name + raise NotImplementedError + + def to_string(self): + if self.id: + return '{}:{}'.format(self.name, self.id) + return self.name + + +class MessageReactionEmoji(Emoji): + pass + class MessageReaction(SlottedModel): emoji = Field(MessageReactionEmoji) @@ -261,6 +275,27 @@ class Message(SlottedModel): """ return self.client.api.channels_messages_delete(self.channel_id, self.id) + def create_reaction(self, emoji): + if isinstance(emoji, Emoji): + emoji = emoji.to_string() + self.client.api.channels_messages_reactions_create( + self.channel_id, + self.id, + emoji) + + def delete_reaction(self, emoji, user=None): + if isinstance(emoji, Emoji): + emoji = emoji.to_string() + + if user: + user = to_snowflake(user) + + self.client.api.channels_messages_reactions_delete( + self.channel_id, + self.id, + emoji, + user) + def is_mentioned(self, entity): """ Returns diff --git a/disco/util/logging.py b/disco/util/logging.py index 7feca4d..5ce9498 100644 --- a/disco/util/logging.py +++ b/disco/util/logging.py @@ -3,15 +3,24 @@ from __future__ import absolute_import import logging +LEVEL_OVERRIDES = { + 'requests': logging.WARNING +} + + +def setup_logging(**kwargs): + logging.basicConfig(**kwargs) + for logger, level in LEVEL_OVERRIDES.items(): + logging.getLogger(logger).setLevel(level) + + class LoggingClass(object): - def __init__(self): - self.log = logging.getLogger(self.__class__.__name__) - - def log_on_error(self, msg, f): - def _f(*args, **kwargs): - try: - return f(*args, **kwargs) - except: - self.log.exception(msg) - raise - return _f + __slots__ = ['_log'] + + @property + def log(self): + try: + return self._log + except AttributeError: + self._log = logging.getLogger(self.__class__.__name__) + return self._log From 6a99eab1a204f5674d44bf0edc68cb679756260b Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 28 Oct 2016 19:19:24 -0500 Subject: [PATCH 16/91] Cleanup event subtype proxying, fix GuildSubType, etc --- disco/gateway/events.py | 51 +++++++++++++++++++++++++++++++++++------ disco/types/channel.py | 4 ++++ disco/types/guild.py | 34 ++++++++++++++++----------- disco/types/user.py | 4 ++-- 4 files changed, 71 insertions(+), 22 deletions(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index d1e3dd5..9afb9c0 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -62,11 +62,9 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)): return cls(obj, client) def __getattr__(self, name): - if hasattr(self, '_wraps_model'): - modname, _ = self._wraps_model - if hasattr(self, modname) and hasattr(getattr(self, modname), name): - return getattr(getattr(self, modname), name) - raise AttributeError(name) + if hasattr(self, '_proxy'): + return getattr(getattr(self, self._proxy), name) + return object.__getattribute__(self, name) def debug(func=None): @@ -93,6 +91,14 @@ def wraps_model(model, alias=None): cls._fields[alias] = Field(model) cls._fields[alias].set_name(alias) cls._wraps_model = (alias, model) + cls._proxy = alias + return cls + return deco + + +def proxy(field): + def deco(cls): + cls._proxy = field return cls return deco @@ -252,6 +258,10 @@ class GuildBanAdd(GatewayEvent): """ guild_id = Field(snowflake) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + @wraps_model(User) class GuildBanRemove(GuildBanAdd): @@ -266,6 +276,10 @@ class GuildBanRemove(GuildBanAdd): The user being unbanned from the guild. """ + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + class GuildEmojisUpdate(GatewayEvent): """ @@ -308,6 +322,10 @@ class GuildMembersChunk(GatewayEvent): guild_id = Field(snowflake) members = Field(listof(GuildMember)) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + @wraps_model(GuildMember, alias='member') class GuildMemberAdd(GatewayEvent): @@ -321,6 +339,7 @@ class GuildMemberAdd(GatewayEvent): """ +@proxy('user') class GuildMemberRemove(GatewayEvent): """ Sent when a user leaves a guild (via leaving, kicking, or banning). @@ -332,8 +351,12 @@ class GuildMemberRemove(GatewayEvent): user : :class:`disco.types.user.User` The user who was removed from the guild. """ - guild_id = Field(snowflake) user = Field(User) + guild_id = Field(snowflake) + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) @wraps_model(GuildMember, alias='member') @@ -348,6 +371,7 @@ class GuildMemberUpdate(GatewayEvent): """ +@proxy('role') class GuildRoleCreate(GatewayEvent): """ Sent when a role is created. @@ -359,10 +383,15 @@ class GuildRoleCreate(GatewayEvent): role : :class:`disco.types.guild.Role` The role that was created. """ - guild_id = Field(snowflake) role = Field(Role) + guild_id = Field(snowflake) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + +@proxy('role') class GuildRoleUpdate(GuildRoleCreate): """ Sent when a role is updated. @@ -375,6 +404,10 @@ class GuildRoleUpdate(GuildRoleCreate): The role that was created. """ + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + class GuildRoleDelete(GatewayEvent): """ @@ -390,6 +423,10 @@ class GuildRoleDelete(GatewayEvent): guild_id = Field(snowflake) role_id = Field(snowflake) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + @wraps_model(Message) class MessageCreate(GatewayEvent): diff --git a/disco/types/channel.py b/disco/types/channel.py index 76e1420..836c315 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -144,6 +144,10 @@ class Channel(SlottedModel, Permissible): return base + @property + def mention(self): + return '<#{}>'.format(self.id) + @property def is_guild(self): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index 14b781a..b058d71 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -23,17 +23,7 @@ VerificationLevel = Enum( ) -class GuildSubType(object): - __slots__ = [] - - guild_id = Field(None) - - @cached_property - def guild(self): - return self.client.state.guilds.get(self.guild_id) - - -class GuildEmoji(Emoji, GuildSubType): +class GuildEmoji(Emoji): """ An emoji object @@ -56,8 +46,12 @@ class GuildEmoji(Emoji, GuildSubType): managed = Field(bool) roles = Field(listof(snowflake)) + @cached_property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + -class Role(SlottedModel, GuildSubType): +class Role(SlottedModel): """ A role object @@ -87,6 +81,9 @@ class Role(SlottedModel, GuildSubType): position = Field(int) mentionable = Field(bool) + def __str__(self): + return self.name + def delete(self): self.guild.delete_role(self) @@ -97,8 +94,12 @@ class Role(SlottedModel, GuildSubType): def mention(self): return '<@{}>'.format(self.id) + @cached_property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + -class GuildMember(SlottedModel, GuildSubType): +class GuildMember(SlottedModel): """ A GuildMember object @@ -127,6 +128,9 @@ class GuildMember(SlottedModel, GuildSubType): joined_at = Field(str) roles = Field(listof(snowflake)) + def __str__(self): + return self.user.__str__() + def get_voice_state(self): """ Returns @@ -186,6 +190,10 @@ class GuildMember(SlottedModel, GuildSubType): """ return self.user.id + @cached_property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + class Guild(SlottedModel, Permissible): """ diff --git a/disco/types/user.py b/disco/types/user.py index cad2cb8..15507d9 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -18,10 +18,10 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): def mention(self): return '<@{}>'.format(self.id) - def to_string(self): + def __str__(self): return '{}#{}'.format(self.username, self.discriminator) - def __str__(self): + def __repr__(self): return ''.format(self.id, self.to_string()) def on_create(self): From 2a311cc33649408ad61e142ff5447104e77aa3ba Mon Sep 17 00:00:00 2001 From: andrei Date: Fri, 28 Oct 2016 21:38:59 -0700 Subject: [PATCH 17/91] Code cleanliness pass --- disco/api/client.py | 4 ++-- disco/api/http.py | 9 ++++++--- disco/bot/bot.py | 5 ++++- disco/bot/command.py | 21 ++++++++++++++------- disco/bot/plugin.py | 23 +++++++++++++++++------ disco/bot/providers/disk.py | 1 + disco/bot/providers/redis.py | 5 +++-- disco/bot/providers/rocksdb.py | 6 ++++-- disco/client.py | 4 ++-- disco/gateway/client.py | 6 +++--- disco/gateway/encoding/json.py | 2 -- disco/gateway/events.py | 4 ++-- disco/gateway/sharder.py | 26 ++++++++++++++------------ disco/types/base.py | 10 +++++----- disco/types/channel.py | 12 ++++++------ disco/types/guild.py | 2 ++ disco/types/message.py | 17 ++++++++--------- disco/types/permissions.py | 3 +++ disco/types/webhook.py | 6 ++++-- disco/util/config.py | 2 +- disco/util/snowflake.py | 2 +- disco/voice/client.py | 14 +++++++------- 22 files changed, 109 insertions(+), 75 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index 869d49c..df13194 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -88,8 +88,8 @@ class APIClient(LoggingClass): def channels_messages_modify(self, channel, message, content): r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, - dict(channel=channel, message=message), - json={'content': content}) + dict(channel=channel, message=message), + json={'content': content}) return Message.create(self.client, r.json()) def channels_messages_delete(self, channel, message): diff --git a/disco/api/http.py b/disco/api/http.py index 8b6fdbb..7fafbf6 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -48,7 +48,8 @@ class Routes(object): CHANNELS_MESSAGES_REACTIONS_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}/reactions/{emoji}') CHANNELS_MESSAGES_REACTIONS_CREATE = (HTTPMethod.PUT, CHANNELS + '/messages/{message}/reactions/{emoji}/@me') CHANNELS_MESSAGES_REACTIONS_DELETE_ME = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/@me') - CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/{user}') + CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE, + CHANNELS + '/messages/{message}/reactions/{emoji}/{user}') CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}') CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}') CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites') @@ -222,13 +223,15 @@ class HTTPClient(LoggingClass): raise APIException('Request failed', r.status_code, r.content) else: if r.status_code == 429: - self.log.warning('Request responded w/ 429, retrying (but this should not happen, check your clock sync') + self.log.warning( + 'Request responded w/ 429, retrying (but this should not happen, check your clock sync') # If we hit the max retries, throw an error retry += 1 if retry > self.MAX_RETRIES: self.log.error('Failing request, hit max retries') - raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) + raise APIException( + 'Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) backoff = self.random_backoff() self.log.warning('Request to `{}` failed with code {}, retrying after {}s ({})'.format( diff --git a/disco/bot/bot.py b/disco/bot/bot.py index c4f391b..ba6ec25 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -245,7 +245,7 @@ class Bot(object): mention_roles = [] if msg.guild: mention_roles = list(filter(lambda r: msg.is_mentioned(r), - msg.guild.get_member(self.client.state.me).roles)) + msg.guild.get_member(self.client.state.me).roles)) if not any(( self.config.commands_mention_rules['user'] and mention_direct, @@ -370,6 +370,9 @@ class Bot(object): Plugin class to initialize and load. config : Optional The configuration to load the plugin with. + ctx : Optional[dict] + Context (previous state) to pass the plugin. Usually used along w/ + unload. """ if cls.__name__ in self.plugins: raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) diff --git a/disco/bot/command.py b/disco/bot/command.py index f2efbb1..137a36c 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -107,16 +107,23 @@ class Command(object): self.plugin = plugin self.func = func self.triggers = [trigger] + + self.args = None + self.level = None + self.group = None + self.is_regex = None + self.oob = False + self.update(*args, **kwargs) def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False): self.triggers += aliases or [] - def resolve_role(ctx, id): - return ctx.msg.guild.roles.get(id) + def resolve_role(ctx, rid): + return ctx.msg.guild.roles.get(rid) - def resolve_user(ctx, id): - return ctx.msg.mentions.get(id) + def resolve_user(ctx, uid): + return ctx.msg.mentions.get(uid) self.args = ArgumentSet.from_string(args or '', { 'mention': self.mention_type([resolve_role, resolve_user]), @@ -136,17 +143,17 @@ class Command(object): if not res: raise TypeError('Invalid mention: {}'.format(i)) - id = int(res.group(1)) + mid = int(res.group(1)) for getter in getters: - obj = getter(ctx, id) + obj = getter(ctx, mid) if obj: return obj if force: raise TypeError('Cannot resolve mention: {}'.format(id)) - return id + return mid return _f @cached_property diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 40b6e49..6ea616e 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -154,6 +154,14 @@ class Plugin(LoggingClass, PluginDeco): self.storage = bot.storage self.config = config + # General declartions + self.listeners = [] + self.commands = {} + self.schedules = {} + self.greenlets = weakref.WeakSet() + self._pre = {} + self._post = {} + # This is an array of all meta functions we sniff at init self.meta_funcs = [] @@ -248,7 +256,7 @@ class Plugin(LoggingClass, PluginDeco): return True - def register_listener(self, func, what, desc, priority=Priority.NONE, conditional=None): + def register_listener(self, func, what, desc, **kwargs): """ Registers a listener @@ -260,15 +268,13 @@ class Plugin(LoggingClass, PluginDeco): The function to be registered. desc The descriptor of the event/packet. - priority : Priority - The priority of this listener. """ func = functools.partial(self._dispatch, 'listener', func) if what == 'event': - li = self.bot.client.events.on(desc, func, priority=priority, conditional=conditional) + li = self.bot.client.events.on(desc, func, **kwargs) elif what == 'packet': - li = self.bot.client.packets.on(desc, func, priority=priority, conditional=conditional) + li = self.bot.client.packets.on(desc, func, **kwargs) else: raise Exception('Invalid listener what: {}'.format(what)) @@ -305,8 +311,13 @@ class Plugin(LoggingClass, PluginDeco): The function to be registered. interval : int Interval (in seconds) to repeat the function on. + repeat : bool + Whether this schedule is repeating (or one time). + init : bool + Whether to run this schedule once immediatly, or wait for the first + scheduled iteration. """ - def repeat(): + def func(): if init: func() diff --git a/disco/bot/providers/disk.py b/disco/bot/providers/disk.py index 5cf1ca3..af259e1 100644 --- a/disco/bot/providers/disk.py +++ b/disco/bot/providers/disk.py @@ -13,6 +13,7 @@ class DiskProvider(BaseProvider): self.fsync = config.get('fsync', False) self.fsync_changes = config.get('fsync_changes', 1) + self.autosave_task = None self.change_count = 0 def autosave_loop(self, interval): diff --git a/disco/bot/providers/redis.py b/disco/bot/providers/redis.py index 239ac9c..1e5150a 100644 --- a/disco/bot/providers/redis.py +++ b/disco/bot/providers/redis.py @@ -10,8 +10,9 @@ from .base import BaseProvider, SEP_SENTINEL class RedisProvider(BaseProvider): def __init__(self, config): - self.config = config + super(RedisProvider, self).__init__(config) self.format = config.get('format', 'pickle') + self.conn = None def load(self): self.conn = redis.Redis( @@ -39,5 +40,5 @@ class RedisProvider(BaseProvider): def set(self, key, value): self.conn.set(key, Serializer.dumps(self.format, value)) - def delete(self, key, value): + def delete(self, key): self.conn.delete(key) diff --git a/disco/bot/providers/rocksdb.py b/disco/bot/providers/rocksdb.py index 0062d79..986268d 100644 --- a/disco/bot/providers/rocksdb.py +++ b/disco/bot/providers/rocksdb.py @@ -12,11 +12,13 @@ from .base import BaseProvider, SEP_SENTINEL class RocksDBProvider(BaseProvider): def __init__(self, config): - self.config = config + super(RocksDBProvider, self).__init__(config) self.format = config.get('format', 'pickle') self.path = config.get('path', 'storage.db') + self.db = None - def k(self, k): + @staticmethod + def k(k): return bytes(k) if six.PY3 else str(k.encode('utf-8')) def load(self): diff --git a/disco/client.py b/disco/client.py index 01b6635..6bda4ef 100644 --- a/disco/client.py +++ b/disco/client.py @@ -98,8 +98,8 @@ class Client(LoggingClass): } self.manhole = DiscoBackdoorServer(self.config.manhole_bind, - banner='Disco Manhole', - localf=lambda: self.manhole_locals) + banner='Disco Manhole', + localf=lambda: self.manhole_locals) self.manhole.start() def update_presence(self, game=None, status=None, afk=False, since=0.0): diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 98d71d6..46d5fa6 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -81,15 +81,15 @@ class GatewayClient(LoggingClass): self.log.debug('Dispatching %s', obj.__class__.__name__) self.client.events.emit(obj.__class__.__name__, obj) - def handle_heartbeat(self, packet): + def handle_heartbeat(self, _): self._send(OPCode.HEARTBEAT, self.seq) - def handle_reconnect(self, packet): + def handle_reconnect(self, _): self.log.warning('Received RECONNECT request, forcing a fresh reconnect') self.session_id = None self.ws.close() - def handle_invalid_session(self, packet): + def handle_invalid_session(self, _): self.log.warning('Recieved INVALID_SESSION, forcing a fresh reconnect') self.session_id = None self.ws.close() diff --git a/disco/gateway/encoding/json.py b/disco/gateway/encoding/json.py index 8810198..8550ac5 100644 --- a/disco/gateway/encoding/json.py +++ b/disco/gateway/encoding/json.py @@ -1,7 +1,5 @@ from __future__ import absolute_import, print_function -import six - try: import ujson as json except ImportError: diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 9afb9c0..6b612ea 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -16,8 +16,8 @@ EVENTS_MAP = {} class GatewayEventMeta(ModelMeta): - def __new__(cls, name, parents, dct): - obj = super(GatewayEventMeta, cls).__new__(cls, name, parents, dct) + def __new__(mcs, name, parents, dct): + obj = super(GatewayEventMeta, mcs).__new__(mcs, name, parents, dct) if name != 'GatewayEvent': EVENTS_MAP[inflection.underscore(name).upper()] = obj diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index c401a3e..45e756a 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -5,6 +5,8 @@ import gevent import logging import marshal +from six.moves import range + from disco.client import Client from disco.bot import Bot, BotConfig from disco.api.client import APIClient @@ -14,13 +16,13 @@ from disco.util.snowflake import calculate_shard from disco.util.serializer import dump_function, load_function -def run_shard(config, id, pipe): +def run_shard(config, shard_id, pipe): setup_logging( level=logging.INFO, - format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) + format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(shard_id) ) - config.shard_id = id + config.shard_id = shard_id client = Client(config) bot = Bot(client, BotConfig(config.bot)) bot.sharder = GIPCProxy(bot, pipe) @@ -34,8 +36,8 @@ class ShardHelper(object): self.bot = bot def keys(self): - for id in xrange(self.count): - yield id + for sid in range(self.count): + yield sid def on(self, id, func): if id == self.bot.client.config.shard_id: @@ -49,8 +51,8 @@ class ShardHelper(object): pool = gevent.pool.Pool(self.count) return dict(zip(range(self.count), pool.imap(lambda i: self.on(i, func).wait(timeout=timeout), range(self.count)))) - def for_id(self, id, func): - shard = calculate_shard(self.count, id) + def for_id(self, sid, func): + shard = calculate_shard(self.count, sid) return self.on(shard, func) @@ -63,9 +65,9 @@ class AutoSharder(object): if self.config.shard_count > 1: self.config.shard_count = 10 - def run_on(self, id, raw): + def run_on(self, sid, raw): func = load_function(raw) - return self.shards[id].execute(func).wait(timeout=15) + return self.shards[sid].execute(func).wait(timeout=15) def run(self): for shard in range(self.config.shard_count): @@ -80,7 +82,7 @@ class AutoSharder(object): format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) ) - def start_shard(self, id): + def start_shard(self, sid): cpipe, ppipe = gipc.pipe(duplex=True, encoder=marshal.dumps, decoder=marshal.loads) - gipc.start_process(run_shard, (self.config, id, cpipe)) - self.shards[id] = GIPCProxy(self, ppipe) + gipc.start_process(run_shard, (self.config, sid, cpipe)) + self.shards[sid] = GIPCProxy(self, ppipe) diff --git a/disco/types/base.py b/disco/types/base.py index 78536a3..1fd69fc 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -179,7 +179,7 @@ def with_equality(field): def with_hash(field): class T(object): - def __hash__(self, other): + def __hash__(self): return hash(getattr(self, field)) return T @@ -190,7 +190,7 @@ SlottedModel = None class ModelMeta(type): - def __new__(cls, name, parents, dct): + def __new__(mcs, name, parents, dct): fields = {} for parent in parents: @@ -217,7 +217,7 @@ class ModelMeta(type): dct = {k: v for k, v in six.iteritems(dct) if k not in fields} dct['_fields'] = fields - return super(ModelMeta, cls).__new__(cls, name, parents, dct) + return super(ModelMeta, mcs).__new__(mcs, name, parents, dct) class AsyncChainable(object): @@ -280,8 +280,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): return inst @classmethod - def create_map(cls, client, data): - return list(map(functools.partial(cls.create, client), data)) + def create_map(cls, client, data, **kwargs): + return list(map(functools.partial(cls.create, client, **kwargs), data)) @classmethod def attach(cls, it, data): diff --git a/disco/types/channel.py b/disco/types/channel.py index 836c315..57d241d 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -57,11 +57,11 @@ class PermissionOverwrite(ChannelSubType): def create(cls, channel, entity, allow=0, deny=0): from disco.types.guild import Role - type = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER + ptype = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER return cls( client=channel.client, id=entity.id, - type=type, + type=ptype, allow=allow, deny=deny, channel_id=channel.id @@ -69,10 +69,10 @@ class PermissionOverwrite(ChannelSubType): def save(self): self.client.api.channels_permissions_modify(self.channel_id, - self.id, - self.allow.value or 0, - self.deny.value or 0, - self.type.name) + self.id, + self.allow.value or 0, + self.deny.value or 0, + self.type.name) return self def delete(self): diff --git a/disco/types/guild.py b/disco/types/guild.py index b058d71..d8a2963 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -41,6 +41,7 @@ class GuildEmoji(Emoji): Roles this emoji is attached to. """ id = Field(snowflake) + guild_id = Field(snowflake) name = Field(text) require_colons = Field(bool) managed = Field(bool) @@ -73,6 +74,7 @@ class Role(SlottedModel): The position of this role in the hierarchy. """ id = Field(snowflake) + guild_id = Field(snowflake) name = Field(text) hoist = Field(bool) managed = Field(bool) diff --git a/disco/types/message.py b/disco/types/message.py index 1df67c5..ce9822f 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -303,8 +303,8 @@ class Message(SlottedModel): bool Whether the give entity was mentioned. """ - id = to_snowflake(entity) - return id in self.mentions or id in self.mention_roles + entity = to_snowflake(entity) + return entity in self.mentions or entity in self.mention_roles @cached_property def without_mentions(self): @@ -340,11 +340,11 @@ class Message(SlottedModel): return def replace(match): - id = match.group(0) - if id in self.mention_roles: - return role_replace(id) + oid = match.group(0) + if oid in self.mention_roles: + return role_replace(oid) else: - return user_replace(self.mentions.get(id)) + return user_replace(self.mentions.get(oid)) return re.sub('<@!?([0-9]+)>', replace, self.content) @@ -376,14 +376,13 @@ class MessageTable(object): data = self.sep.lstrip() for idx, col in enumerate(cols): - padding = ' ' * ((self.size_index[idx] - len(col))) + padding = ' ' * (self.size_index[idx] - len(col)) data += col + padding + self.sep return data.rstrip() def compile(self): - data = [] - data.append(self.compile_one(self.header)) + data = [self.compile_one(self.header)] if self.header_break: data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1)) diff --git a/disco/types/permissions.py b/disco/types/permissions.py index aa7260c..7afb3f7 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -107,6 +107,9 @@ class PermissionValue(object): class Permissible(object): __slots__ = [] + def get_permissions(self): + raise NotImplementedError + def can(self, user, *args): perms = self.get_permissions(user) return perms.administrator or perms.can(*args) diff --git a/disco/types/webhook.py b/disco/types/webhook.py index 3afdd3f..4a630d3 100644 --- a/disco/types/webhook.py +++ b/disco/types/webhook.py @@ -32,12 +32,14 @@ class Webhook(SlottedModel): else: return self.client.api.webhooks_modify(self.id, name, avatar) - def execute(self, content=None, username=None, avatar_url=None, tts=False, file=None, embeds=[], wait=False): + def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False): + # TODO: support file stuff properly + return self.client.api.webhooks_token_execute(self.id, self.token, { 'content': content, 'username': username, 'avatar_url': avatar_url, 'tts': tts, - 'file': file, + 'file': fobj, 'embeds': [i.to_dict() for i in embeds], }, wait) diff --git a/disco/util/config.py b/disco/util/config.py index 29147c2..30d2996 100644 --- a/disco/util/config.py +++ b/disco/util/config.py @@ -29,7 +29,7 @@ class Config(object): return inst def from_prefix(self, prefix): - prefix = prefix + '_' + prefix += '_' obj = {} for k, v in six.iteritems(self.__dict__): diff --git a/disco/util/snowflake.py b/disco/util/snowflake.py index 241e2de..b2f512f 100644 --- a/disco/util/snowflake.py +++ b/disco/util/snowflake.py @@ -17,7 +17,7 @@ def to_unix(snowflake): def to_unix_ms(snowflake): - return ((int(snowflake) >> 22) + DISCORD_EPOCH) + return (int(snowflake) >> 22) + DISCORD_EPOCH def to_snowflake(i): diff --git a/disco/voice/client.py b/disco/voice/client.py index 568174f..69fe9d5 100644 --- a/disco/voice/client.py +++ b/disco/voice/client.py @@ -102,6 +102,7 @@ class VoiceClient(LoggingClass): self.endpoint = None self.ssrc = None self.port = None + self.udp = None self.update_listener = None @@ -149,7 +150,7 @@ class VoiceClient(LoggingClass): } }) - def on_voice_sdp(self, data): + def on_voice_sdp(self, _): # Toggle speaking state so clients learn of our SSRC self.set_speaking(True) self.set_speaking(False) @@ -178,19 +179,18 @@ class VoiceClient(LoggingClass): ) self.ws.run_forever() - def on_message(self, ws, msg): + def on_message(self, _, msg): try: data = self.encoder.decode(msg) + self.packets.emit(VoiceOPCode[data['op']], data['d']) except: self.log.exception('Failed to parse voice gateway message: ') - self.packets.emit(VoiceOPCode[data['op']], data['d']) - - def on_error(self, ws, err): + def on_error(self, _, err): # TODO self.log.warning('Voice websocket error: {}'.format(err)) - def on_open(self, ws): + def on_open(self, _): self.send(VoiceOPCode.IDENTIFY, { 'server_id': self.channel.guild_id, 'user_id': self.client.state.me.id, @@ -198,7 +198,7 @@ class VoiceClient(LoggingClass): 'token': self.token }) - def on_close(self, ws, code, error): + def on_close(self, _, code, error): # TODO self.log.warning('Voice websocket disconnected (%s, %s)', code, error) From 9891b900d6c28cb1df3f07a6bd732a741ee1e029 Mon Sep 17 00:00:00 2001 From: andrei Date: Sun, 30 Oct 2016 19:49:45 -0700 Subject: [PATCH 18/91] Modeling improvements, couple other fixes Modeling fields are now drastically better, no more dictof/listof bullshit, we now properly have ListField/DictField/etc. - Fix setting self nickname --- disco/api/client.py | 3 + disco/api/http.py | 1 + disco/gateway/events.py | 24 +++++--- disco/types/base.py | 126 ++++++++++++++++++++++------------------ disco/types/channel.py | 9 ++- disco/types/guild.py | 40 ++++++++----- disco/types/message.py | 17 +++--- 7 files changed, 131 insertions(+), 89 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index df13194..27152e4 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -196,6 +196,9 @@ class APIClient(LoggingClass): def guilds_members_modify(self, guild, member, **kwargs): self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) + def guilds_members_me_nick(self, guild, nick): + self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick}) + def guilds_members_kick(self, guild, member): self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member)) diff --git a/disco/api/http.py b/disco/api/http.py index 7fafbf6..31035e4 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -71,6 +71,7 @@ class Routes(object): GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members') GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}') GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}') + GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick') GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}') GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans') GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}') diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6b612ea..6799795 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -9,7 +9,7 @@ from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState from disco.types.guild import Guild, GuildMember, Role, Emoji -from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime +from disco.types.base import Model, ModelMeta, Field, ListField, snowflake, lazy_datetime # Mapping of discords event name to our event classes EVENTS_MAP = {} @@ -89,7 +89,7 @@ def wraps_model(model, alias=None): def deco(cls): cls._fields[alias] = Field(model) - cls._fields[alias].set_name(alias) + cls._fields[alias].name = alias cls._wraps_model = (alias, model) cls._proxy = alias return cls @@ -124,8 +124,8 @@ class Ready(GatewayEvent): version = Field(int, alias='v') session_id = Field(str) user = Field(User) - guilds = Field(listof(Guild)) - private_channels = Field(listof(Channel)) + guilds = ListField(Guild) + private_channels = ListField(Guild) class Resumed(GatewayEvent): @@ -293,7 +293,7 @@ class GuildEmojisUpdate(GatewayEvent): The new set of emojis for the guild """ guild_id = Field(snowflake) - emojis = Field(listof(Emoji)) + emojis = ListField(Emoji) class GuildIntegrationsUpdate(GatewayEvent): @@ -320,7 +320,7 @@ class GuildMembersChunk(GatewayEvent): The chunk of members. """ guild_id = Field(snowflake) - members = Field(listof(GuildMember)) + members = ListField(GuildMember) @property def guild(self): @@ -466,6 +466,14 @@ class MessageDelete(GatewayEvent): id = Field(snowflake) channel_id = Field(snowflake) + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild + class MessageDeleteBulk(GatewayEvent): """ @@ -479,7 +487,7 @@ class MessageDeleteBulk(GatewayEvent): List of messages being deleted in the channel. """ channel_id = Field(snowflake) - ids = Field(listof(snowflake)) + ids = ListField(snowflake) @wraps_model(Presence) @@ -497,7 +505,7 @@ class PresenceUpdate(GatewayEvent): List of roles the user from the presence is part of. """ guild_id = Field(snowflake) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) class TypingStart(GatewayEvent): diff --git a/disco/types/base.py b/disco/types/base.py index 1fd69fc..3bea43a 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -1,4 +1,5 @@ import six +import sys import gevent import inspect import functools @@ -19,49 +20,31 @@ class ConversionError(Exception): def __init__(self, field, raw, e): super(ConversionError, self).__init__( 'Failed to convert `{}` (`{}`) to {}: {}'.format( - str(raw)[:144], field.src_name, field.typ, e)) + str(raw)[:144], field.src_name, field.deserializer, e)) -class FieldType(object): - def __init__(self, typ): - if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model): - self.typ = typ - elif isinstance(typ, BaseEnumMeta): - self.typ = lambda raw, _: typ.get(raw) - elif typ is None: - self.typ = lambda x, y: None - else: - self.typ = lambda raw, _: typ(raw) - - def serialize(self, value): - if isinstance(value, EnumAttr): - return value.value - elif isinstance(value, Model): - return value.to_dict() - else: - return value - - def try_convert(self, raw, client): - pass - - def __call__(self, raw, client): - return self.try_convert(raw, client) +class Field(object): + def __init__(self, value_type, alias=None, default=None): + self.src_name = alias + self.dst_name = None + if not hasattr(self, 'default'): + self.default = default -class Field(FieldType): - def __init__(self, typ, alias=None, default=None): - super(Field, self).__init__(typ) + self.deserializer = None - # Set names - self.src_name = alias - self.dst_name = None + if value_type: + self.deserializer = self.type_to_deserializer(value_type) - self.default = default + if isinstance(self.deserializer, Field): + self.default = self.deserializer.default - if isinstance(self.typ, FieldType): - self.default = self.typ.default + @property + def name(self): + return None - def set_name(self, name): + @name.setter + def name(self, name): if not self.dst_name: self.dst_name = name @@ -73,31 +56,68 @@ class Field(FieldType): def try_convert(self, raw, client): try: - return self.typ(raw, client) + return self.deserializer(raw, client) except Exception as e: - six.raise_from(ConversionError(self, raw, e), e) + exc_info = sys.exc_info() + raise ConversionError(self, raw, e), exc_info[1], exc_info[2] + @staticmethod + def type_to_deserializer(typ): + if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model): + return typ + elif isinstance(typ, BaseEnumMeta): + return lambda raw, _: typ.get(raw) + elif typ is None: + return lambda x, y: None + else: + return lambda raw, _: typ(raw) -class _Dict(FieldType): + @staticmethod + def serialize(value): + if isinstance(value, EnumAttr): + return value.value + elif isinstance(value, Model): + return value.to_dict() + else: + return value + + def __call__(self, raw, client): + return self.try_convert(raw, client) + + +class DictField(Field): default = HashMap - def __init__(self, typ, key=None): - super(_Dict, self).__init__(typ) - self.key = key + def __init__(self, key_type, value_type=None, **kwargs): + super(DictField, self).__init__(None, **kwargs) + self.key_de = self.type_to_deserializer(key_type) + self.value_de = self.type_to_deserializer(value_type or key_type) def try_convert(self, raw, client): - if self.key: - converted = [self.typ(i, client) for i in raw] - return HashMap({getattr(i, self.key): i for i in converted}) - else: - return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)}) + return HashMap({ + self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) + }) -class _List(FieldType): +class ListField(Field): default = list def try_convert(self, raw, client): - return [self.typ(i, client) for i in raw] + return [self.deserializer(i, client) for i in raw] + + +class AutoDictField(Field): + default = HashMap + + def __init__(self, value_type, key, **kwargs): + super(AutoDictField, self).__init__(None, **kwargs) + self.value_de = self.type_to_deserializer(value_type) + self.key = key + + def try_convert(self, raw, client): + return HashMap({ + getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw) + }) def _make(typ, data, client): @@ -116,14 +136,6 @@ def enum(typ): return _f -def listof(*args, **kwargs): - return _List(*args, **kwargs) - - -def dictof(*args, **kwargs): - return _Dict(*args, **kwargs) - - def lazy_datetime(data): if not data: return property(lambda: None) @@ -201,7 +213,7 @@ class ModelMeta(type): if not isinstance(v, Field): continue - v.set_name(k) + v.name = k fields[k] = v if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)): diff --git a/disco/types/channel.py b/disco/types/channel.py index 57d241d..97eeb38 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -5,7 +5,7 @@ from holster.enum import Enum from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property, one_or_many, chunks from disco.types.user import User -from disco.types.base import SlottedModel, Field, snowflake, enum, listof, dictof, text +from disco.types.base import SlottedModel, Field, ListField, AutoDictField, snowflake, enum, text from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.voice.client import VoiceClient @@ -111,15 +111,18 @@ class Channel(SlottedModel, Permissible): last_message_id = Field(snowflake) position = Field(int) bitrate = Field(int) - recipients = Field(listof(User)) + recipients = ListField(User) type = Field(enum(ChannelType)) - overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') + overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites') def __init__(self, *args, **kwargs): super(Channel, self).__init__(*args, **kwargs) self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) + def __str__(self): + return '#{}'.format(self.name) + def get_permissions(self, user): """ Get the permissions a user has in the channel diff --git a/disco/types/guild.py b/disco/types/guild.py index d8a2963..1f7ddd4 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -6,7 +6,9 @@ from disco.gateway.packets import OPCode from disco.api.http import APIException from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property -from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum +from disco.types.base import ( + SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum +) from disco.types.user import User, Presence from disco.types.voice import VoiceState from disco.types.channel import Channel @@ -45,7 +47,7 @@ class GuildEmoji(Emoji): name = Field(text) require_colons = Field(bool) managed = Field(bool) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) @cached_property def guild(self): @@ -128,7 +130,7 @@ class GuildMember(SlottedModel): mute = Field(bool) deaf = Field(bool) joined_at = Field(str) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) def __str__(self): return self.user.__str__() @@ -169,7 +171,10 @@ class GuildMember(SlottedModel): nickname : Optional[str] The nickname (or none to reset) to set. """ - self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') + if self.client.state.me.id == self.user.id: + self.client.api.guilds_members_me_nick(self.guild.id, nick=nickname or '') + else: + self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') def add_role(self, role): roles = self.roles + [role.id] @@ -196,6 +201,10 @@ class GuildMember(SlottedModel): def guild(self): return self.client.state.guilds.get(self.guild_id) + @cached_property + def permissions(self): + return self.guild.get_permissions(self) + class Guild(SlottedModel, Permissible): """ @@ -252,14 +261,14 @@ class Guild(SlottedModel, Permissible): embed_enabled = Field(bool) verification_level = Field(enum(VerificationLevel)) mfa_level = Field(int) - features = Field(listof(str)) - members = Field(dictof(GuildMember, key='id')) - channels = Field(dictof(Channel, key='id')) - roles = Field(dictof(Role, key='id')) - emojis = Field(dictof(GuildEmoji, key='id')) - voice_states = Field(dictof(VoiceState, key='session_id')) + features = ListField(str) + members = AutoDictField(GuildMember, 'id') + channels = AutoDictField(Channel, 'id') + roles = AutoDictField(Role, 'id') + emojis = AutoDictField(GuildEmoji, 'id') + voice_states = AutoDictField(VoiceState, 'session_id') member_count = Field(int) - presences = Field(listof(Presence)) + presences = ListField(Presence) synced = Field(bool, default=False) @@ -272,7 +281,7 @@ class Guild(SlottedModel, Permissible): self.attach(six.itervalues(self.emojis), {'guild_id': self.id}) self.attach(six.itervalues(self.voice_states), {'guild_id': self.id}) - def get_permissions(self, user): + def get_permissions(self, member): """ Get the permissions a user has in this guild. @@ -281,10 +290,13 @@ class Guild(SlottedModel, Permissible): :class:`disco.types.permissions.PermissionValue` Computed permission value for the user. """ - if self.owner_id == user.id: + if not isinstance(member, GuildMember): + member = self.get_member(member) + + # Owner has all permissions + if self.owner_id == member.id: return PermissionValue(Permissions.ADMINISTRATOR) - member = self.get_member(user) value = PermissionValue(self.roles.get(self.id).permissions) for role in map(self.roles.get, member.roles): diff --git a/disco/types/message.py b/disco/types/message.py index ce9822f..e53ccb5 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -2,7 +2,10 @@ import re from holster.enum import Enum -from disco.types.base import SlottedModel, Field, snowflake, text, lazy_datetime, dictof, listof, enum +from disco.types.base import ( + SlottedModel, Field, ListField, AutoDictField, snowflake, text, + lazy_datetime, enum +) from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.user import User @@ -109,7 +112,7 @@ class MessageEmbed(SlottedModel): thumbnail = Field(MessageEmbedThumbnail) video = Field(MessageEmbedVideo) author = Field(MessageEmbedAuthor) - fields = Field(listof(MessageEmbedField)) + fields = ListField(MessageEmbedField) class MessageAttachment(SlottedModel): @@ -191,11 +194,11 @@ class Message(SlottedModel): tts = Field(bool) mention_everyone = Field(bool) pinned = Field(bool) - mentions = Field(dictof(User, key='id')) - mention_roles = Field(listof(snowflake)) - embeds = Field(listof(MessageEmbed)) - attachments = Field(dictof(MessageAttachment, key='id')) - reactions = Field(listof(MessageReaction)) + mentions = AutoDictField(User, 'id') + mention_roles = ListField(snowflake) + embeds = ListField(MessageEmbed) + attachments = AutoDictField(MessageAttachment, 'id') + reactions = ListField(MessageReaction) def __str__(self): return ''.format(self.id, self.channel_id) From d8d1df0ac4281582d7e144444bcf10ecea91a573 Mon Sep 17 00:00:00 2001 From: Michael Date: Sun, 30 Oct 2016 19:52:12 -0700 Subject: [PATCH 19/91] period at end of every docstring first line, ca -> can (#9) --- disco/api/ratelimit.py | 8 ++++---- disco/bot/bot.py | 6 +++--- disco/bot/command.py | 12 ++++++------ disco/bot/parser.py | 14 +++++++------- disco/bot/plugin.py | 28 ++++++++++++++-------------- disco/client.py | 6 +++--- disco/state.py | 2 +- disco/types/channel.py | 24 ++++++++++++------------ disco/types/guild.py | 10 +++++----- disco/types/invite.py | 2 +- disco/types/message.py | 8 ++++---- disco/util/token.py | 2 +- 12 files changed, 61 insertions(+), 61 deletions(-) diff --git a/disco/api/ratelimit.py b/disco/api/ratelimit.py index 244b291..054c8cf 100644 --- a/disco/api/ratelimit.py +++ b/disco/api/ratelimit.py @@ -44,7 +44,7 @@ class RouteState(LoggingClass): @property def chilled(self): """ - Whether this route is currently being cooldown (aka waiting until reset_time) + Whether this route is currently being cooldown (aka waiting until reset_time). """ return self.event is not None @@ -74,7 +74,7 @@ class RouteState(LoggingClass): def wait(self, timeout=None): """ - Waits until this route is no longer under a cooldown + Waits until this route is no longer under a cooldown. Parameters ---------- @@ -85,13 +85,13 @@ class RouteState(LoggingClass): Returns ------- bool - False if the timeout period expired before the cooldown was finished + False if the timeout period expired before the cooldown was finished. """ return self.event.wait(timeout) def cooldown(self): """ - Waits for the current route to be cooled-down (aka waiting until reset time) + Waits for the current route to be cooled-down (aka waiting until reset time). """ if self.reset_time - time.time() < 0: raise Exception('Cannot cooldown for negative time period; check clock sync') diff --git a/disco/bot/bot.py b/disco/bot/bot.py index ba6ec25..369aea8 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -177,7 +177,7 @@ class Bot(object): @property def commands(self): """ - Generator of all commands this bots plugins have defined + Generator of all commands this bots plugins have defined. """ for plugin in six.itervalues(self.plugins): for command in six.itervalues(plugin.commands): @@ -194,7 +194,7 @@ class Bot(object): def compute_group_abbrev(self): """ - Computes all possible abbreviations for a command grouping + Computes all possible abbreviations for a command grouping. """ self.group_abbrev = {} groups = set(command.group for command in self.commands if command.group) @@ -417,7 +417,7 @@ class Bot(object): def run_forever(self): """ - Runs this bots core loop forever + Runs this bots core loop forever. """ self.client.run_forever() diff --git a/disco/bot/command.py b/disco/bot/command.py index 137a36c..5843046 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -48,28 +48,28 @@ class CommandEvent(object): @cached_property def member(self): """ - Guild member (if relevant) for the user that created the message + Guild member (if relevant) for the user that created the message. """ return self.guild.get_member(self.author) @property def channel(self): """ - Channel the message was created in + Channel the message was created in. """ return self.msg.channel @property def guild(self): """ - Guild (if relevant) the message was created in + Guild (if relevant) the message was created in. """ return self.msg.guild @property def author(self): """ - Author of the message + Author of the message. """ return self.msg.author @@ -159,14 +159,14 @@ class Command(object): @cached_property def compiled_regex(self): """ - A compiled version of this command's regex + A compiled version of this command's regex. """ return re.compile(self.regex) @property def regex(self): """ - The regex string that defines/triggers this command + The regex string that defines/triggers this command. """ if self.is_regex: return REGEX_FMT.format('|'.join(self.triggers)) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 8f3483e..fab4513 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -47,13 +47,13 @@ class Argument(object): @property def true_count(self): """ - The true number of raw arguments this argument takes + The true number of raw arguments this argument takes. """ return self.count or 1 def parse(self, raw): """ - Attempts to parse arguments from their raw form + Attempts to parse arguments from their raw form. """ prefix, part = raw @@ -78,7 +78,7 @@ class Argument(object): class ArgumentSet(object): """ - A set of :class:`Argument` instances which forms a larger argument specification + A set of :class:`Argument` instances which forms a larger argument specification. Attributes ---------- @@ -95,7 +95,7 @@ class ArgumentSet(object): @classmethod def from_string(cls, line, custom_types=None): """ - Creates a new :class:`ArgumentSet` from a given argument string specification + Creates a new :class:`ArgumentSet` from a given argument string specification. """ args = cls(custom_types=custom_types) @@ -131,7 +131,7 @@ class ArgumentSet(object): def append(self, arg): """ - Add a new :class:`Argument` to this argument specification/set + Add a new :class:`Argument` to this argument specification/set. """ if self.args and not self.args[-1].required and arg.required: raise Exception('Required argument cannot come after an optional argument') @@ -178,13 +178,13 @@ class ArgumentSet(object): @property def length(self): """ - The number of arguments in this set/specification + The number of arguments in this set/specification. """ return len(self.args) @property def required_length(self): """ - The number of required arguments to compile this set/specificaiton + The number of required arguments to compile this set/specificaiton. """ return sum([i.true_count for i in self.args if i.required]) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 6ea616e..10044c1 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -43,7 +43,7 @@ class PluginDeco(object): @classmethod def listen(cls, *args, **kwargs): """ - Binds the function to listen for a given event name + Binds the function to listen for a given event name. """ return cls.add_meta_deco({ 'type': 'listener', @@ -55,7 +55,7 @@ class PluginDeco(object): @classmethod def listen_packet(cls, *args, **kwargs): """ - Binds the function to listen for a given gateway op code + Binds the function to listen for a given gateway op code. """ return cls.add_meta_deco({ 'type': 'listener', @@ -67,7 +67,7 @@ class PluginDeco(object): @classmethod def command(cls, *args, **kwargs): """ - Creates a new command attached to the function + Creates a new command attached to the function. """ return cls.add_meta_deco({ 'type': 'command', @@ -78,7 +78,7 @@ class PluginDeco(object): @classmethod def pre_command(cls): """ - Runs a function before a command is triggered + Runs a function before a command is triggered. """ return cls.add_meta_deco({ 'type': 'pre_command', @@ -87,7 +87,7 @@ class PluginDeco(object): @classmethod def post_command(cls): """ - Runs a function after a command is triggered + Runs a function after a command is triggered. """ return cls.add_meta_deco({ 'type': 'post_command', @@ -96,7 +96,7 @@ class PluginDeco(object): @classmethod def pre_listener(cls): """ - Runs a function before a listener is triggered + Runs a function before a listener is triggered. """ return cls.add_meta_deco({ 'type': 'pre_listener', @@ -105,7 +105,7 @@ class PluginDeco(object): @classmethod def post_listener(cls): """ - Runs a function after a listener is triggered + Runs a function after a listener is triggered. """ return cls.add_meta_deco({ 'type': 'post_listener', @@ -114,7 +114,7 @@ class PluginDeco(object): @classmethod def schedule(cls, *args, **kwargs): """ - Runs a function repeatedly, waiting for a specified interval + Runs a function repeatedly, waiting for a specified interval. """ return cls.add_meta_deco({ 'type': 'schedule', @@ -212,7 +212,7 @@ class Plugin(LoggingClass, PluginDeco): def execute(self, event): """ - Executes a CommandEvent this plugin owns + Executes a CommandEvent this plugin owns. """ if not event.command.oob: self.greenlets.add(gevent.getcurrent()) @@ -226,7 +226,7 @@ class Plugin(LoggingClass, PluginDeco): def register_trigger(self, typ, when, func): """ - Registers a trigger + Registers a trigger. """ getattr(self, '_' + when)[typ].append(func) @@ -258,7 +258,7 @@ class Plugin(LoggingClass, PluginDeco): def register_listener(self, func, what, desc, **kwargs): """ - Registers a listener + Registers a listener. Parameters ---------- @@ -282,7 +282,7 @@ class Plugin(LoggingClass, PluginDeco): def register_command(self, func, *args, **kwargs): """ - Registers a command + Registers a command. Parameters ---------- @@ -331,13 +331,13 @@ class Plugin(LoggingClass, PluginDeco): def load(self, ctx): """ - Called when the plugin is loaded + Called when the plugin is loaded. """ pass def unload(self, ctx): """ - Called when the plugin is unloaded + Called when the plugin is unloaded. """ for greenlet in self.greenlets: greenlet.kill() diff --git a/disco/client.py b/disco/client.py index 6bda4ef..0de6f45 100644 --- a/disco/client.py +++ b/disco/client.py @@ -20,7 +20,7 @@ class ClientConfig(Config): Attributes ---------- token : str - Discord authentication token, ca be validated using the + Discord authentication token, can be validated using the :func:`disco.util.token.is_valid_token` function. shard_id : int The shard ID for the current client instance. @@ -123,12 +123,12 @@ class Client(LoggingClass): def run(self): """ - Run the client (e.g. the :class:`GatewayClient`) in a new greenlet + Run the client (e.g. the :class:`GatewayClient`) in a new greenlet. """ return gevent.spawn(self.gw.run) def run_forever(self): """ - Run the client (e.g. the :class:`GatewayClient`) in the current greenlet + Run the client (e.g. the :class:`GatewayClient`) in the current greenlet. """ return self.gw.run() diff --git a/disco/state.py b/disco/state.py index abe4cff..cf4636e 100644 --- a/disco/state.py +++ b/disco/state.py @@ -117,7 +117,7 @@ class State(object): def unbind(self): """ - Unbinds all bound event listeners for this state object + Unbinds all bound event listeners for this state object. """ map(lambda k: k.unbind(), self.listeners) self.listeners = [] diff --git a/disco/types/channel.py b/disco/types/channel.py index 97eeb38..f5389d1 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -33,7 +33,7 @@ class ChannelSubType(SlottedModel): class PermissionOverwrite(ChannelSubType): """ - A PermissionOverwrite for a :class:`Channel` + A PermissionOverwrite for a :class:`Channel`. Attributes ---------- @@ -81,7 +81,7 @@ class PermissionOverwrite(ChannelSubType): class Channel(SlottedModel, Permissible): """ - Represents a Discord Channel + Represents a Discord Channel. Attributes ---------- @@ -125,7 +125,7 @@ class Channel(SlottedModel, Permissible): def get_permissions(self, user): """ - Get the permissions a user has in the channel + Get the permissions a user has in the channel. Returns ------- @@ -154,42 +154,42 @@ class Channel(SlottedModel, Permissible): @property def is_guild(self): """ - Whether this channel belongs to a guild + Whether this channel belongs to a guild. """ return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE) @property def is_dm(self): """ - Whether this channel is a DM (does not belong to a guild) + Whether this channel is a DM (does not belong to a guild). """ return self.type in (ChannelType.DM, ChannelType.GROUP_DM) @property def is_voice(self): """ - Whether this channel supports voice + Whether this channel supports voice. """ return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM) @property def messages(self): """ - a default :class:`MessageIterator` for the channel + a default :class:`MessageIterator` for the channel. """ return self.messages_iter() @cached_property def guild(self): """ - Guild this channel belongs to (if relevant) + Guild this channel belongs to (if relevant). """ return self.client.state.guilds.get(self.guild_id) def messages_iter(self, **kwargs): """ Creates a new :class:`MessageIterator` for the channel with the given - keyword arguments + keyword arguments. """ return MessageIterator(self.client, self, **kwargs) @@ -232,7 +232,7 @@ class Channel(SlottedModel, Permissible): def send_message(self, content, nonce=None, tts=False): """ - Send a message in this channel + Send a message in this channel. Parameters ---------- @@ -252,7 +252,7 @@ class Channel(SlottedModel, Permissible): def connect(self, *args, **kwargs): """ - Connect to this channel over voice + Connect to this channel over voice. """ assert self.is_voice, 'Channel must support voice to connect' vc = VoiceClient(self) @@ -351,7 +351,7 @@ class MessageIterator(object): def fill(self): """ - Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API + Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API. """ self._buffer = self.client.api.channels_messages_list( self.channel.id, diff --git a/disco/types/guild.py b/disco/types/guild.py index 1f7ddd4..220ffa4 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -27,7 +27,7 @@ VerificationLevel = Enum( class GuildEmoji(Emoji): """ - An emoji object + An emoji object. Attributes ---------- @@ -56,7 +56,7 @@ class GuildEmoji(Emoji): class Role(SlottedModel): """ - A role object + A role object. Attributes ---------- @@ -105,7 +105,7 @@ class Role(SlottedModel): class GuildMember(SlottedModel): """ - A GuildMember object + A GuildMember object. Attributes ---------- @@ -193,7 +193,7 @@ class GuildMember(SlottedModel): @property def id(self): """ - Alias to the guild members user id + Alias to the guild members user id. """ return self.user.id @@ -208,7 +208,7 @@ class GuildMember(SlottedModel): class Guild(SlottedModel, Permissible): """ - A guild object + A guild object. Attributes ---------- diff --git a/disco/types/invite.py b/disco/types/invite.py index 850002e..906a360 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -6,7 +6,7 @@ from disco.types.channel import Channel class Invite(SlottedModel): """ - An invite object + An invite object. Attributes ---------- diff --git a/disco/types/message.py b/disco/types/message.py index e53ccb5..14115ef 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -88,7 +88,7 @@ class MessageEmbedField(SlottedModel): class MessageEmbed(SlottedModel): """ - Message embed object + Message embed object. Attributes ---------- @@ -117,7 +117,7 @@ class MessageEmbed(SlottedModel): class MessageAttachment(SlottedModel): """ - Message attachment object + Message attachment object. Attributes ---------- @@ -242,7 +242,7 @@ class Message(SlottedModel): def reply(self, *args, **kwargs): """ Reply to this message (proxys arguments to - :func:`disco.types.channel.Channel.send_message`) + :func:`disco.types.channel.Channel.send_message`). Returns ------- @@ -253,7 +253,7 @@ class Message(SlottedModel): def edit(self, content): """ - Edit this message + Edit this message. Args ---- diff --git a/disco/util/token.py b/disco/util/token.py index c48beca..d71b93d 100644 --- a/disco/util/token.py +++ b/disco/util/token.py @@ -5,6 +5,6 @@ TOKEN_RE = re.compile(r'M\w{23}\.[\w-]{6}\..{27}') def is_valid_token(token): """ - Validates a Discord authentication token, returning true if valid + Validates a Discord authentication token, returning true if valid. """ return bool(TOKEN_RE.match(token)) From c81044f82d9eb21b5f4294aed7229a9c9d0af1b0 Mon Sep 17 00:00:00 2001 From: andrei Date: Sun, 30 Oct 2016 21:46:57 -0700 Subject: [PATCH 20/91] Fix exception chaining, defaults, bump holster --- disco/types/base.py | 23 ++++++++++++++++++----- requirements.txt | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/disco/types/base.py b/disco/types/base.py index 3bea43a..0bce0c8 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -22,21 +22,27 @@ class ConversionError(Exception): 'Failed to convert `{}` (`{}`) to {}: {}'.format( str(raw)[:144], field.src_name, field.deserializer, e)) + if six.PY3: + self.__cause__ = e + class Field(object): - def __init__(self, value_type, alias=None, default=None): + def __init__(self, value_type, alias=None, default=None, test=0): self.src_name = alias self.dst_name = None + self.test = test - if not hasattr(self, 'default'): + if default is not None: self.default = default + elif not hasattr(self, 'default'): + self.default = None self.deserializer = None if value_type: self.deserializer = self.type_to_deserializer(value_type) - if isinstance(self.deserializer, Field): + if isinstance(self.deserializer, Field) and self.default is None: self.default = self.deserializer.default @property @@ -58,8 +64,13 @@ class Field(object): try: return self.deserializer(raw, client) except Exception as e: - exc_info = sys.exc_info() - raise ConversionError(self, raw, e), exc_info[1], exc_info[2] + err = ConversionError(self, raw, e) + + if six.PY2: + exc_info = sys.exc_info() + raise ConversionError, err, exc_info[2] + else: + raise err @staticmethod def type_to_deserializer(typ): @@ -132,6 +143,8 @@ def snowflake(data): def enum(typ): def _f(data): + if isinstance(data, str): + data = data.lower() return typ.get(data) if data is not None else None return _f diff --git a/requirements.txt b/requirements.txt index 9a1e8fd..a41d421 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ gevent==1.1.2 -holster==1.0.8 +holster==1.0.9 inflection==0.3.1 requests==2.11.1 six==1.10.0 From 92958309a832b124a8bf952460d66f37f3258cb7 Mon Sep 17 00:00:00 2001 From: andrei Date: Sun, 30 Oct 2016 22:01:49 -0700 Subject: [PATCH 21/91] Use six.reraise --- disco/types/base.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/disco/types/base.py b/disco/types/base.py index 0bce0c8..611672e 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -1,5 +1,4 @@ import six -import sys import gevent import inspect import functools @@ -64,13 +63,7 @@ class Field(object): try: return self.deserializer(raw, client) except Exception as e: - err = ConversionError(self, raw, e) - - if six.PY2: - exc_info = sys.exc_info() - raise ConversionError, err, exc_info[2] - else: - raise err + six.reraise(ConversionError, ConversionError(self, raw, e)) @staticmethod def type_to_deserializer(typ): From 9330d4a8dd627c83874d789dce4e3e817dd9af8a Mon Sep 17 00:00:00 2001 From: andrei Date: Wed, 2 Nov 2016 17:33:37 -0700 Subject: [PATCH 22/91] Couple of fixes --- disco/bot/bot.py | 2 ++ disco/types/base.py | 19 ++++++++++--------- disco/types/user.py | 2 +- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 369aea8..f5282ae 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -431,6 +431,8 @@ class Bot(object): for entry in map(lambda i: getattr(mod, i), dir(mod)): if inspect.isclass(entry) and issubclass(entry, Plugin) and not entry == Plugin: + if getattr(entry, '_shallow', False) and Plugin in entry.__bases__: + continue loaded = True self.add_plugin(entry, config) diff --git a/disco/types/base.py b/disco/types/base.py index 611672e..17fde2e 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -142,19 +142,20 @@ def enum(typ): return _f +# TODO: make lazy def lazy_datetime(data): if not data: - return property(lambda: None) + return None - def get(): - for fmt in DATETIME_FORMATS: - try: - return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) - except (ValueError, TypeError): - continue - raise ValueError('Failed to conver `{}` to datetime'.format(data)) + if isinstance(data, int): + return real_datetime.utcfromtimestamp(data) - return property(get) + for fmt in DATETIME_FORMATS: + try: + return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) + except (ValueError, TypeError): + continue + raise ValueError('Failed to conver `{}` to datetime'.format(data)) def datetime(data): diff --git a/disco/types/user.py b/disco/types/user.py index 15507d9..a860843 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -8,7 +8,7 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): username = Field(text) avatar = Field(binary) discriminator = Field(str) - bot = Field(bool) + bot = Field(bool, default=False) verified = Field(bool) email = Field(str) From 137e4c1dcb5c3efee80e8bd9763482f6a5446fbb Mon Sep 17 00:00:00 2001 From: LewisHogan Date: Thu, 10 Nov 2016 06:36:39 +0000 Subject: [PATCH 23/91] Fixed minor typo in README for running example. (#10) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6c52569..ae77c31 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ class SimplePlugin(Plugin): Using the default bot configuration, we can now run this script like so: -`python -m disco.cli --token="MY_DISCORD_TOKEN" --bot --plugin simpleplugin` +`python -m disco.cli --token="MY_DISCORD_TOKEN" --run-bot --plugin simpleplugin` And commands can be triggered by mentioning the bot (configued by the BotConfig.command\_require\_mention flag): From 8db18a6a1febfb1491614dee2930386e597415cf Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 2 Nov 2016 19:33:51 -0500 Subject: [PATCH 24/91] Fix killing heartbeat task when it wasnt created, better loads/dumps --- disco/gateway/client.py | 5 ++++- disco/gateway/sharder.py | 20 +++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 46d5fa6..5ff6956 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -147,6 +147,8 @@ class GatewayClient(LoggingClass): raise Exception('WS recieved error: %s', error) def on_open(self): + self.log.info('Opened, headers: %s', self.ws.sock.headers) + if self.seq and self.session_id: self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.send(OPCode.RESUME, { @@ -175,7 +177,8 @@ class GatewayClient(LoggingClass): def on_close(self, code, reason): # Kill heartbeater, a reconnect/resume will trigger a HELLO which will # respawn it - self._heartbeat_task.kill() + if self._heartbeat_task: + self._heartbeat_task.kill() # If we're quitting, just break out of here if self.shutting_down: diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index 45e756a..0321cee 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -2,6 +2,7 @@ from __future__ import absolute_import import gipc import gevent +import pickle import logging import marshal @@ -82,7 +83,24 @@ class AutoSharder(object): format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) ) + @staticmethod + def dumps(data): + if isinstance(data, (basestring, int, long, bool, list, set, dict)): + return '\x01' + marshal.dumps(data) + elif isinstance(data, object) and data.__class__.__name__ == 'code': + return '\x01' + marshal.dumps(data) + else: + return '\x02' + pickle.dumps(data) + + @staticmethod + def loads(data): + enc_type = data[0] + if enc_type == '\x01': + return marshal.loads(data[1:]) + elif enc_type == '\x02': + return pickle.loads(data[1:]) + def start_shard(self, sid): - cpipe, ppipe = gipc.pipe(duplex=True, encoder=marshal.dumps, decoder=marshal.loads) + cpipe, ppipe = gipc.pipe(duplex=True, encoder=self.dumps, decoder=self.loads) gipc.start_process(run_shard, (self.config, sid, cpipe)) self.shards[sid] = GIPCProxy(self, ppipe) From 28cfa830f7ec6416bfe7cc524d72cb137c11420a Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 7 Nov 2016 16:00:00 -0600 Subject: [PATCH 25/91] Various fixes and tweaks - Add plugin context on load - Allow providing additional kwargs to a command context - Commands are now stored using their base trigger instead of function name, which allows for multiple bindings on a single function - Use argument passing for register_listener - Fix RedisProvider.get_many failing when its passed an empty list of keys - Add some utility properties on MessageReactionAdd/Remove - Support type conversion toggling in MessageTable - Latest holster fixes --- disco/bot/bot.py | 3 ++- disco/bot/command.py | 6 ++++-- disco/bot/plugin.py | 16 +++++++++------- disco/bot/providers/redis.py | 4 ++++ disco/gateway/events.py | 16 ++++++++++++++++ disco/types/message.py | 3 ++- disco/types/permissions.py | 4 ++-- 7 files changed, 39 insertions(+), 13 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index f5282ae..428f76a 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -383,9 +383,10 @@ class Bot(object): else: config = self.load_plugin_config(cls) - self.plugins[cls.__name__] = cls(self, config) + self.ctx['plugin'] = self.plugins[cls.__name__] = cls(self, config) self.plugins[cls.__name__].load(ctx or {}) self.recompute() + self.ctx.drop() def rmv_plugin(self, cls): """ diff --git a/disco/bot/command.py b/disco/bot/command.py index 5843046..5ef873f 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -113,10 +113,11 @@ class Command(object): self.group = None self.is_regex = None self.oob = False + self.context = {} self.update(*args, **kwargs) - def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False): + def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None): self.triggers += aliases or [] def resolve_role(ctx, rid): @@ -135,6 +136,7 @@ class Command(object): self.group = group self.is_regex = is_regex self.oob = oob + self.context = context or {} @staticmethod def mention_type(getters, force=False): @@ -201,4 +203,4 @@ class Command(object): except ArgumentError as e: raise CommandError(e.message) - return self.func(event, *args) + return self.func(event, *args, **self.context) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 10044c1..3764f38 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -256,7 +256,7 @@ class Plugin(LoggingClass, PluginDeco): return True - def register_listener(self, func, what, desc, **kwargs): + def register_listener(self, func, what, *args, **kwargs): """ Registers a listener. @@ -269,12 +269,12 @@ class Plugin(LoggingClass, PluginDeco): desc The descriptor of the event/packet. """ - func = functools.partial(self._dispatch, 'listener', func) + args = list(args) + [functools.partial(self._dispatch, 'listener', func)] if what == 'event': - li = self.bot.client.events.on(desc, func, **kwargs) + li = self.bot.client.events.on(*args, **kwargs) elif what == 'packet': - li = self.bot.client.packets.on(desc, func, **kwargs) + li = self.bot.client.packets.on(*args, **kwargs) else: raise Exception('Invalid listener what: {}'.format(what)) @@ -294,11 +294,13 @@ class Plugin(LoggingClass, PluginDeco): Keyword arguments to pass onto the :class:`disco.bot.command.Command` object. """ - if kwargs.pop('update', False) and func.__name__ in self.commands: - self.commands[func.__name__].update(*args, **kwargs) + name = args[0] + + if kwargs.pop('update', False) and name in self.commands: + self.commands[name].update(*args, **kwargs) else: wrapped = functools.partial(self._dispatch, 'command', func) - self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs) + self.commands[name] = Command(self, wrapped, *args, **kwargs) def register_schedule(self, func, interval, repeat=True, init=True): """ diff --git a/disco/bot/providers/redis.py b/disco/bot/providers/redis.py index 1e5150a..f5e1375 100644 --- a/disco/bot/providers/redis.py +++ b/disco/bot/providers/redis.py @@ -31,6 +31,10 @@ class RedisProvider(BaseProvider): yield key def get_many(self, keys): + keys = list(keys) + if not len(keys): + raise StopIteration + for key, value in izip(keys, self.conn.mget(keys)): yield (key, Serializer.loads(self.format, value)) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6799795..3e8af68 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -591,6 +591,14 @@ class MessageReactionAdd(GatewayEvent): user_id = Field(snowflake) emoji = Field(MessageReactionEmoji) + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild + class MessageReactionRemove(GatewayEvent): """ @@ -611,3 +619,11 @@ class MessageReactionRemove(GatewayEvent): message_id = Field(snowflake) user_id = Field(snowflake) emoji = Field(MessageReactionEmoji) + + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild diff --git a/disco/types/message.py b/disco/types/message.py index 14115ef..b5eaf17 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -371,7 +371,8 @@ class MessageTable(object): self.recalculate_size_index(args) def add(self, *args): - args = list(map(str, args)) + convert = lambda v: v if isinstance(v, basestring) else str(v) + args = list(map(convert, args)) self.entries.append(args) self.recalculate_size_index(args) diff --git a/disco/types/permissions.py b/disco/types/permissions.py index 7afb3f7..ff43145 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -76,13 +76,13 @@ class PermissionValue(object): return self.sub(other) def __getattribute__(self, name): - if name in Permissions.attrs: + if name in Permissions.keys_: return (self.value & Permissions[name].value) == Permissions[name].value else: return object.__getattribute__(self, name) def __setattr__(self, name, value): - if name not in Permissions.attrs: + if name not in Permissions.keys_: return super(PermissionValue, self).__setattr__(name, value) if value: From 58ea923562b3decd907b0dfbc6356d32678d0a7d Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 7 Nov 2016 16:17:07 -0600 Subject: [PATCH 26/91] Better command function introspection, docstrings --- disco/bot/command.py | 9 +++++++-- disco/bot/plugin.py | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index 5ef873f..e8ef431 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -108,6 +108,7 @@ class Command(object): self.func = func self.triggers = [trigger] + self.dispatch_func = None self.args = None self.level = None self.group = None @@ -117,7 +118,10 @@ class Command(object): self.update(*args, **kwargs) - def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None): + def get_docstring(self): + return (self.func.__doc__ or '').format(**self.context) + + def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None, dispatch_func=None): self.triggers += aliases or [] def resolve_role(ctx, rid): @@ -137,6 +141,7 @@ class Command(object): self.is_regex = is_regex self.oob = oob self.context = context or {} + self.dispatch_func = dispatch_func @staticmethod def mention_type(getters, force=False): @@ -203,4 +208,4 @@ class Command(object): except ArgumentError as e: raise CommandError(e.message) - return self.func(event, *args, **self.context) + return (self.dispatch_func or self.func)(event, *args, **self.context) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 3764f38..0e05136 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -300,7 +300,8 @@ class Plugin(LoggingClass, PluginDeco): self.commands[name].update(*args, **kwargs) else: wrapped = functools.partial(self._dispatch, 'command', func) - self.commands[name] = Command(self, wrapped, *args, **kwargs) + kwargs.setdefault('dispatch_func', wrapped) + self.commands[name] = Command(self, func, *args, **kwargs) def register_schedule(self, func, interval, repeat=True, init=True): """ From b5284c19756bdc40d49c1c88b50f360996b64985 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 10 Nov 2016 19:28:14 -0600 Subject: [PATCH 27/91] Various fixes and improvements - Add support for attachments and message embeds - Fix commands being weirdly stored by some key (which doesn't make sense) - Added CommandEvent.codeblock which represents the first codeblock in the message (useful for eval like commands) - Cleanup the spawn utilties on plugin a bit - Fix GuildBanAdd/GuildBanRemove - Unset model fields are now a special sentinel value - etc stuff --- disco/api/client.py | 16 +++++++++++++--- disco/bot/bot.py | 2 +- disco/bot/command.py | 26 ++++++++++++++++++++----- disco/bot/plugin.py | 42 +++++++++++++++++++++++++++-------------- disco/gateway/events.py | 20 +++++++++++++------- disco/state.py | 18 ++++++++++++++---- disco/types/__init__.py | 1 + disco/types/base.py | 28 +++++++++++++++++++++------ disco/types/channel.py | 9 ++++++--- disco/types/user.py | 18 ++++++++++++------ 10 files changed, 131 insertions(+), 49 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index 27152e4..3a15171 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -77,12 +77,22 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) return Message.create(self.client, r.json()) - def channels_messages_create(self, channel, content, nonce=None, tts=False): - r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={ + def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None): + payload = { 'content': content, 'nonce': nonce, 'tts': tts, - }) + } + + if embed: + payload['embed'] = embed.to_dict() + + if attachment: + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data=payload, files={ + 'file': (attachment[0], attachment[1]) + }) + else: + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload) return Message.create(self.client, r.json()) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 428f76a..ec670e1 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -180,7 +180,7 @@ class Bot(object): Generator of all commands this bots plugins have defined. """ for plugin in six.itervalues(self.plugins): - for command in six.itervalues(plugin.commands): + for command in plugin.commands: yield command def recompute(self): diff --git a/disco/bot/command.py b/disco/bot/command.py index e8ef431..40ccbea 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -45,6 +45,18 @@ class CommandEvent(object): self.name = self.match.group(1) self.args = [i for i in self.match.group(2).strip().split(' ') if i] + @property + def codeblock(self): + _, src = self.msg.content.split('`', 1) + src = '`' + src + + if src.startswith('```') and src.endswith('```'): + src = src[3:-3] + elif src.startswith('`') and src.endswith('`'): + src = src[1:-1] + + return src + @cached_property def member(self): """ @@ -146,11 +158,15 @@ class Command(object): @staticmethod def mention_type(getters, force=False): def _f(ctx, i): - res = MENTION_RE.match(i) - if not res: - raise TypeError('Invalid mention: {}'.format(i)) - - mid = int(res.group(1)) + # TODO: support full discrim format? make this betteR? + if i.isdigit(): + mid = int(i) + else: + res = MENTION_RE.match(i) + if not res: + raise TypeError('Invalid mention: {}'.format(i)) + + mid = int(res.group(1)) for getter in getters: obj = getter(ctx, mid) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 0e05136..fcbc9ad 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -156,7 +156,7 @@ class Plugin(LoggingClass, PluginDeco): # General declartions self.listeners = [] - self.commands = {} + self.commands = [] self.schedules = {} self.greenlets = weakref.WeakSet() self._pre = {} @@ -182,7 +182,7 @@ class Plugin(LoggingClass, PluginDeco): def bind_all(self): self.listeners = [] - self.commands = {} + self.commands = [] self.schedules = {} self.greenlets = weakref.WeakSet() @@ -197,7 +197,7 @@ class Plugin(LoggingClass, PluginDeco): if meta['type'] == 'listener': self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs']) elif meta['type'] == 'command': - meta['kwargs']['update'] = True + # meta['kwargs']['update'] = True self.register_command(member, *meta['args'], **meta['kwargs']) elif meta['type'] == 'schedule': self.register_schedule(member, *meta['args'], **meta['kwargs']) @@ -205,11 +205,25 @@ class Plugin(LoggingClass, PluginDeco): when, typ = meta['type'].split('_', 1) self.register_trigger(typ, when, member) - def spawn(self, method, *args, **kwargs): - obj = gevent.spawn(method, *args, **kwargs) + def spawn_wrap(self, spawner, method, *args, **kwargs): + def wrapped(*args, **kwargs): + self.ctx['plugin'] = self + try: + res = method(*args, **kwargs) + return res + finally: + self.ctx.drop() + + obj = spawner(wrapped, *args, **kwargs) self.greenlets.add(obj) return obj + def spawn(self, *args, **kwargs): + return self.spawn_wrap(gevent.spawn, *args, **kwargs) + + def spawn_later(self, delay, *args, **kwargs): + return self.spawn_wrap(functools.partial(gevent.spawn_later, delay), *args, **kwargs) + def execute(self, event): """ Executes a CommandEvent this plugin owns. @@ -294,14 +308,14 @@ class Plugin(LoggingClass, PluginDeco): Keyword arguments to pass onto the :class:`disco.bot.command.Command` object. """ - name = args[0] + # name = args[0] - if kwargs.pop('update', False) and name in self.commands: - self.commands[name].update(*args, **kwargs) - else: - wrapped = functools.partial(self._dispatch, 'command', func) - kwargs.setdefault('dispatch_func', wrapped) - self.commands[name] = Command(self, func, *args, **kwargs) + # if kwargs.pop('update', False) and name in self.commands: + # self.commands[name].update(*args, **kwargs) + # else: + wrapped = functools.partial(self._dispatch, 'command', func) + kwargs.setdefault('dispatch_func', wrapped) + self.commands.append(Command(self, func, *args, **kwargs)) def register_schedule(self, func, interval, repeat=True, init=True): """ @@ -320,7 +334,7 @@ class Plugin(LoggingClass, PluginDeco): Whether to run this schedule once immediatly, or wait for the first scheduled iteration. """ - def func(): + def repeat_func(): if init: func() @@ -330,7 +344,7 @@ class Plugin(LoggingClass, PluginDeco): if not repeat: break - self.schedules[func.__name__] = self.spawn(repeat) + self.schedules[func.__name__] = self.spawn(repeat_func) def load(self, ctx): """ diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 3e8af68..5fc28cf 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -67,15 +67,16 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)): return object.__getattribute__(self, name) -def debug(func=None): +def debug(func=None, match=None): def deco(cls): old_init = cls.__init__ def new_init(self, obj, *args, **kwargs): - if func: - print(func(obj)) - else: - print(obj) + if not match or match(obj): + if func: + print(func(obj)) + else: + print(obj) old_init(self, obj, *args, **kwargs) @@ -244,7 +245,7 @@ class ChannelPinsUpdate(GatewayEvent): last_pin_timestamp = Field(lazy_datetime) -@wraps_model(User) +@proxy(User) class GuildBanAdd(GatewayEvent): """ Sent when a user is banned from a guild. @@ -257,13 +258,14 @@ class GuildBanAdd(GatewayEvent): The user being banned from the guild. """ guild_id = Field(snowflake) + user = Field(User) @property def guild(self): return self.client.state.guilds.get(self.guild_id) -@wraps_model(User) +@proxy(User) class GuildBanRemove(GuildBanAdd): """ Sent when a user is unbanned from a guild. @@ -507,6 +509,10 @@ class PresenceUpdate(GatewayEvent): guild_id = Field(snowflake) roles = ListField(snowflake) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + class TypingStart(GatewayEvent): """ diff --git a/disco/state.py b/disco/state.py index cf4636e..81efbb3 100644 --- a/disco/state.py +++ b/disco/state.py @@ -1,8 +1,8 @@ import six +import weakref import inflection from collections import deque, namedtuple -from weakref import WeakValueDictionary from gevent.event import Event from disco.util.config import Config @@ -102,9 +102,9 @@ class State(object): self.me = None self.dms = HashMap() self.guilds = HashMap() - self.channels = HashMap(WeakValueDictionary()) - self.users = HashMap(WeakValueDictionary()) - self.voice_states = HashMap(WeakValueDictionary()) + self.channels = HashMap(weakref.WeakValueDictionary()) + self.users = HashMap(weakref.WeakValueDictionary()) + self.voice_states = HashMap(weakref.WeakValueDictionary()) # If message tracking is enabled, listen to those events if self.config.track_messages: @@ -298,4 +298,14 @@ class State(object): def on_presence_update(self, event): if event.user.id in self.users: + self.users[event.user.id].update(event.presence.user) self.users[event.user.id].presence = event.presence + event.presence.user = self.users[event.user.id] + + if event.guild_id not in self.guilds: + return + + if event.user.id not in self.guilds[event.guild_id].members: + return + + self.guilds[event.guild_id].members[event.user.id].user.update(event.user) diff --git a/disco/types/__init__.py b/disco/types/__init__.py index 5e6f73b..5824ec5 100644 --- a/disco/types/__init__.py +++ b/disco/types/__init__.py @@ -1,3 +1,4 @@ +from disco.types.base import UNSET from disco.types.channel import Channel from disco.types.guild import Guild, GuildMember, Role from disco.types.user import User diff --git a/disco/types/base.py b/disco/types/base.py index 17fde2e..1013906 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -15,6 +15,14 @@ DATETIME_FORMATS = [ ] +class Unset(object): + def __nonzero__(self): + return False + + +UNSET = Unset() + + class ConversionError(Exception): def __init__(self, field, raw, e): super(ConversionError, self).__init__( @@ -26,10 +34,9 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None, test=0): + def __init__(self, value_type, alias=None, default=None): self.src_name = alias self.dst_name = None - self.test = test if default is not None: self.default = default @@ -97,6 +104,10 @@ class DictField(Field): self.key_de = self.type_to_deserializer(key_type) self.value_de = self.type_to_deserializer(value_type or key_type) + @staticmethod + def serialize(value): + return {Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value)} + def try_convert(self, raw, client): return HashMap({ self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) @@ -106,6 +117,10 @@ class DictField(Field): class ListField(Field): default = list + @staticmethod + def serialize(value): + return list(map(Field.serialize, value)) + def try_convert(self, raw, client): return [self.deserializer(i, client) for i in raw] @@ -265,7 +280,7 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): if field.has_default(): default = field.default() if callable(field.default) else field.default else: - default = None + default = UNSET setattr(self, field.dst_name, default) continue @@ -274,9 +289,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): def update(self, other): for name in six.iterkeys(self.__class__._fields): - value = getattr(other, name) - if value: - setattr(self, name, value) + if hasattr(other, name) and not getattr(other, name) is UNSET: + setattr(self, name, getattr(other, name)) # Clear cached properties for name in dir(type(self)): @@ -289,6 +303,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): def to_dict(self): obj = {} for name, field in six.iteritems(self.__class__._fields): + if getattr(self, name) == UNSET: + continue obj[name] = field.serialize(getattr(self, name)) return obj diff --git a/disco/types/channel.py b/disco/types/channel.py index f5389d1..3912535 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -121,7 +121,10 @@ class Channel(SlottedModel, Permissible): self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) def __str__(self): - return '#{}'.format(self.name) + return u'#{}'.format(self.name) + + def __repr__(self): + return u''.format(self.id, self) def get_permissions(self, user): """ @@ -230,7 +233,7 @@ class Channel(SlottedModel, Permissible): def create_webhook(self, name=None, avatar=None): return self.client.api.channels_webhooks_create(self.id, name, avatar) - def send_message(self, content, nonce=None, tts=False): + def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None): """ Send a message in this channel. @@ -248,7 +251,7 @@ class Channel(SlottedModel, Permissible): :class:`disco.types.message.Message` The created message. """ - return self.client.api.channels_messages_create(self.id, content, nonce, tts) + return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed) def connect(self, *args, **kwargs): """ diff --git a/disco/types/user.py b/disco/types/user.py index a860843..c2bac79 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -14,18 +14,24 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): presence = Field(None) + @property + def avatar_url(self): + if not self.avatar: + return None + + return 'https://discordapp.com/api/users/{}/avatars/{}.jpg'.format( + self.id, + self.avatar) + @property def mention(self): return '<@{}>'.format(self.id) def __str__(self): - return '{}#{}'.format(self.username, self.discriminator) + return u'{}#{}'.format(self.username, self.discriminator) def __repr__(self): - return ''.format(self.id, self.to_string()) - - def on_create(self): - self.client.state.users[self.id] = self + return u''.format(self.id, self) GameType = Enum( @@ -49,6 +55,6 @@ class Game(SlottedModel): class Presence(SlottedModel): - user = Field(User) + user = Field(User, alias='user') game = Field(Game) status = Field(Status) From dd75502b89a5ef629182a2493101d9fac5e9a810 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 12 Nov 2016 07:22:19 -0600 Subject: [PATCH 28/91] Improvements to command processing - Bot.get_commands_for_message is now more composable/externally useable - Command parser for 'user' type has been improved to allow username/discrim combos - Model loading can now be done outside of the model constructor, and supports some utility arguments - Fix sub-model fields not having their default value be the sub-model constructor - Fix Message.without_mentions - Add Message.with_proper_mentions (e.g. humanifying the message) - Cleanup Message.replace_mentions for the above two changes - Fix some weird casting inside MessageTable --- disco/bot/bot.py | 45 +++++++++++++++++--------------- disco/bot/command.py | 31 ++++++++++++---------- disco/types/base.py | 27 ++++++++++++++++---- disco/types/message.py | 58 +++++++++++++++++++++++++++++++----------- 4 files changed, 108 insertions(+), 53 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index ec670e1..2f97832 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -221,7 +221,7 @@ class Bot(object): else: self.command_matches_re = None - def get_commands_for_message(self, msg): + def get_commands_for_message(self, require_mention, mention_rules, prefix, msg): """ Generator of all commands that a given message object triggers, based on the bots plugins and configuration. @@ -238,7 +238,7 @@ class Bot(object): """ content = msg.content - if self.config.commands_require_mention: + if require_mention: mention_direct = msg.is_mentioned(self.client.state.me) mention_everyone = msg.mention_everyone @@ -248,9 +248,9 @@ class Bot(object): msg.guild.get_member(self.client.state.me).roles)) if not any(( - self.config.commands_mention_rules['user'] and mention_direct, - self.config.commands_mention_rules['everyone'] and mention_everyone, - self.config.commands_mention_rules['role'] and any(mention_roles), + mention_rules.get('user', True) and mention_direct, + mention_rules.get('everyone', False) and mention_everyone, + mention_rules.get('role', False) and any(mention_roles), msg.channel.is_dm )): raise StopIteration @@ -270,10 +270,10 @@ class Bot(object): content = content.lstrip() - if self.config.commands_prefix and not content.startswith(self.config.commands_prefix): + if prefix and not content.startswith(prefix): raise StopIteration else: - content = content[len(self.config.commands_prefix):] + content = content[len(prefix):] if not self.command_matches_re or not self.command_matches_re.match(content): raise StopIteration @@ -324,19 +324,24 @@ class Bot(object): bool whether any commands where successfully triggered by the message """ - commands = list(self.get_commands_for_message(msg)) - - if len(commands): - result = False - for command, match in commands: - if not self.check_command_permissions(command, msg): - continue - - if command.plugin.execute(CommandEvent(command, msg, match)): - result = True - return result - - return False + commands = list(self.get_commands_for_message( + self.config.commands_require_mention, + self.config.commands_mention_rules, + self.config.commands_prefix, + msg + )) + + if not len(commands): + return False + + result = False + for command, match in commands: + if not self.check_command_permissions(command, msg): + continue + + if command.plugin.execute(CommandEvent(command, msg, match)): + result = True + return result def on_message_create(self, event): if event.message.author.id == self.client.state.me.id: diff --git a/disco/bot/command.py b/disco/bot/command.py index 40ccbea..08c9b9f 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -140,11 +140,14 @@ class Command(object): return ctx.msg.guild.roles.get(rid) def resolve_user(ctx, uid): - return ctx.msg.mentions.get(uid) + if isinstance(uid, int): + return ctx.msg.mentions.get(uid) + else: + return ctx.msg.mentions.select_one(username=uid[0], discriminator=uid[1]) self.args = ArgumentSet.from_string(args or '', { 'mention': self.mention_type([resolve_role, resolve_user]), - 'user': self.mention_type([resolve_user], force=True), + 'user': self.mention_type([resolve_user], force=True, user=True), 'role': self.mention_type([resolve_role], force=True), }) @@ -156,27 +159,29 @@ class Command(object): self.dispatch_func = dispatch_func @staticmethod - def mention_type(getters, force=False): - def _f(ctx, i): - # TODO: support full discrim format? make this betteR? - if i.isdigit(): - mid = int(i) + def mention_type(getters, force=False, user=False): + def _f(ctx, raw): + if raw.isdigit(): + resolved = int(raw) + elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit(): + username, discrim = raw.split('#') + resolved = (username, int(discrim)) else: - res = MENTION_RE.match(i) + res = MENTION_RE.match(raw) if not res: - raise TypeError('Invalid mention: {}'.format(i)) + raise TypeError('Invalid mention: {}'.format(raw)) - mid = int(res.group(1)) + resolved = int(res.group(1)) for getter in getters: - obj = getter(ctx, mid) + obj = getter(ctx, resolved) if obj: return obj if force: - raise TypeError('Cannot resolve mention: {}'.format(id)) + raise TypeError('Cannot resolve mention: {}'.format(raw)) - return mid + return resolved return _f @cached_property diff --git a/disco/types/base.py b/disco/types/base.py index 1013906..56657ff 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -34,9 +34,10 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None): + def __init__(self, value_type, alias=None, default=None, **kwargs): self.src_name = alias self.dst_name = None + self.metadata = kwargs if default is not None: self.default = default @@ -50,6 +51,8 @@ class Field(object): if isinstance(self.deserializer, Field) and self.default is None: self.default = self.deserializer.default + elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None: + self.default = self.deserializer @property def name(self): @@ -275,8 +278,22 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): else: obj = kwargs - for name, field in six.iteritems(self.__class__._fields): - if field.src_name not in obj or obj[field.src_name] is None: + self.load(obj) + + @property + def fields(self): + return self.__class__._fields + + def load(self, obj, consume=False, skip=None): + for name, field in six.iteritems(self.fields): + should_skip = skip and name in skip + + if consume and not should_skip: + raw = obj.pop(field.src_name, None) + else: + raw = obj.get(field.src_name, None) + + if raw is None or should_skip: if field.has_default(): default = field.default() if callable(field.default) else field.default else: @@ -284,11 +301,11 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): setattr(self, field.dst_name, default) continue - value = field.try_convert(obj[field.src_name], self.client) + value = field.try_convert(raw, self.client) setattr(self, field.dst_name, value) def update(self, other): - for name in six.iterkeys(self.__class__._fields): + for name in six.iterkeys(self.fields): if hasattr(other, name) and not getattr(other, name) is UNSET: setattr(self, name, getattr(other, name)) diff --git a/disco/types/message.py b/disco/types/message.py index b5eaf17..e7b2873 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,4 +1,5 @@ import re +import functools from holster.enum import Enum @@ -310,18 +311,33 @@ class Message(SlottedModel): return entity in self.mentions or entity in self.mention_roles @cached_property - def without_mentions(self): + def without_mentions(self, valid_only=False): """ Returns ------- str - the message contents with all valid mentions removed. + the message contents with all mentions removed. """ return self.replace_mentions( lambda u: '', - lambda r: '') + lambda r: '', + lambda c: '', + nonexistant=not valid_only) - def replace_mentions(self, user_replace, role_replace): + @cached_property + def with_proper_mentions(self): + def replace_user(u): + return '@' + str(u) + + def replace_role(r): + return '@' + str(r) + + def replace_channel(c): + return str(c) + + return self.replace_mentions(replace_user, replace_role, replace_channel) + + def replace_mentions(self, user_replace=None, role_replace=None, channel_replace=None, nonexistant=False): """ Replaces user and role mentions with the result of a given lambda/function. @@ -339,17 +355,30 @@ class Message(SlottedModel): str The message contents with all valid mentions replaced. """ - if not self.mentions and not self.mention_roles: - return + def replace(getter, func, match): + oid = int(match.group(2)) + obj = getter(oid) + + if obj or nonexistant: + return func(obj or oid) or match.group(0) + + return match.group(0) + + content = self.content + + if user_replace: + replace_user = functools.partial(replace, self.mentions.get, user_replace) + content = re.sub('(<@!?([0-9]+)>)', replace_user, self.content) + + if role_replace: + replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace) + content = re.sub('(<@&([0-9]+)>)', replace_role, content) - def replace(match): - oid = match.group(0) - if oid in self.mention_roles: - return role_replace(oid) - else: - return user_replace(self.mentions.get(oid)) + if channel_replace: + replace_channel = functools.partial(replace, self.client.state.channels.get, channel_replace) + content = re.sub('(<#([0-9]+)>)', replace_channel, content) - return re.sub('<@!?([0-9]+)>', replace, self.content) + return content class MessageTable(object): @@ -371,8 +400,7 @@ class MessageTable(object): self.recalculate_size_index(args) def add(self, *args): - convert = lambda v: v if isinstance(v, basestring) else str(v) - args = list(map(convert, args)) + args = list(map(lambda v: v if isinstance(v, basestring) else str(v), args)) self.entries.append(args) self.recalculate_size_index(args) From 07321c925008d9619d559b5108bb95cff63ad3be Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 12 Nov 2016 17:23:15 -0600 Subject: [PATCH 29/91] Rename Model.fields to avoid smashing other model props, etc --- disco/bot/command.py | 23 +++++++++++++---------- disco/types/base.py | 7 ++++--- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index 08c9b9f..1cf00a1 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -7,7 +7,10 @@ from disco.util.functional import cached_property REGEX_FMT = '({})' ARGS_REGEX = '( ((?:\n|.)*)$|$)' -MENTION_RE = re.compile('<@!?([0-9]+)>') + +USER_MENTION_RE = re.compile('<@!?([0-9]+)>') +ROLE_MENTION_RE = re.compile('<@&([0-9]+)>') +CHANNEL_MENTION_RE = re.compile('<#([0-9]+)>') CommandLevels = Enum( DEFAULT=0, @@ -145,10 +148,13 @@ class Command(object): else: return ctx.msg.mentions.select_one(username=uid[0], discriminator=uid[1]) + def resolve_channel(ctx, cid): + return ctx.msg.guild.channels.get(cid) + self.args = ArgumentSet.from_string(args or '', { - 'mention': self.mention_type([resolve_role, resolve_user]), - 'user': self.mention_type([resolve_user], force=True, user=True), - 'role': self.mention_type([resolve_role], force=True), + 'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True), + 'role': self.mention_type([resolve_role], ROLE_MENTION_RE), + 'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE), }) self.level = level @@ -159,7 +165,7 @@ class Command(object): self.dispatch_func = dispatch_func @staticmethod - def mention_type(getters, force=False, user=False): + def mention_type(getters, reg, user=False): def _f(ctx, raw): if raw.isdigit(): resolved = int(raw) @@ -167,7 +173,7 @@ class Command(object): username, discrim = raw.split('#') resolved = (username, int(discrim)) else: - res = MENTION_RE.match(raw) + res = reg.match(raw) if not res: raise TypeError('Invalid mention: {}'.format(raw)) @@ -178,10 +184,7 @@ class Command(object): if obj: return obj - if force: - raise TypeError('Cannot resolve mention: {}'.format(raw)) - - return resolved + raise TypeError('Cannot resolve mention: {}'.format(raw)) return _f @cached_property diff --git a/disco/types/base.py b/disco/types/base.py index 56657ff..dfa1e2c 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -73,6 +73,7 @@ class Field(object): try: return self.deserializer(raw, client) except Exception as e: + raise six.reraise(ConversionError, ConversionError(self, raw, e)) @staticmethod @@ -281,11 +282,11 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): self.load(obj) @property - def fields(self): + def _fields(self): return self.__class__._fields def load(self, obj, consume=False, skip=None): - for name, field in six.iteritems(self.fields): + for name, field in six.iteritems(self._fields): should_skip = skip and name in skip if consume and not should_skip: @@ -305,7 +306,7 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): setattr(self, field.dst_name, value) def update(self, other): - for name in six.iterkeys(self.fields): + for name in six.iterkeys(self._fields): if hasattr(other, name) and not getattr(other, name) is UNSET: setattr(self, name, getattr(other, name)) From 6f8cecdcf70bbc6836b81c686ff56bd29805e5fa Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 12 Nov 2016 17:29:06 -0600 Subject: [PATCH 30/91] Fix filtered args dict compt in HTTPClient --- disco/api/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/api/http.py b/disco/api/http.py index 31035e4..951a143 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -203,7 +203,7 @@ class HTTPClient(LoggingClass): # Build the bucket URL args = {to_bytes(k): to_bytes(v) for k, v in six.iteritems(args)} - filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in six.iteritems(args)} + filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)} bucket = (route[0].value, route[1].format(**filtered)) # Possibly wait if we're rate limited From bc8e494fba192768168e698eaacc82edcb2733cb Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 12 Nov 2016 21:06:23 -0600 Subject: [PATCH 31/91] bugfix - don't to_bytes keys in args --- disco/api/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/api/http.py b/disco/api/http.py index 951a143..2ebf941 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -202,7 +202,7 @@ class HTTPClient(LoggingClass): kwargs['headers'] = self.headers # Build the bucket URL - args = {to_bytes(k): to_bytes(v) for k, v in six.iteritems(args)} + args = {k: to_bytes(v) for k, v in six.iteritems(args)} filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)} bucket = (route[0].value, route[1].format(**filtered)) From fa2f915de5c906a0dad9906a253ec95cd91549b1 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sun, 13 Nov 2016 01:37:04 -0600 Subject: [PATCH 32/91] Remove some debug stuff --- disco/api/client.py | 18 ++++++++++++++---- disco/bot/bot.py | 2 +- disco/bot/command.py | 3 +++ disco/types/message.py | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index 3a15171..c8fe702 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -23,16 +23,26 @@ def optional(**kwargs): class APIClient(LoggingClass): """ - An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits - the models with the returned data. + An abstraction over a :class:`disco.api.http.HTTPClient`, which composes + requests from provided data, and fits models with the returned data. The APIClient + is the only path to the API used within models/other interfaces, and it's + the recommended path for all third-party users/implementations. Args ---- token : str The Discord authentication token (without prefixes) to be used for all HTTP requests. - client : :class:`disco.client.Client` - The base disco client which will be used when constructing models. + client : Optional[:class:`disco.client.Client`] + The Disco client this APIClient is a member of. This is used when constructing + and fitting models from response data. + + Attributes + ---------- + client : Optional[:class:`disco.client.Client`] + The Disco client this APIClient is a member of. + http : :class:`disco.http.HTTPClient` + The HTTPClient this APIClient uses for all requests. """ def __init__(self, token, client=None): super(APIClient, self).__init__() diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 2f97832..e4ba373 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -203,7 +203,7 @@ class Bot(object): grp = group while grp: # If the group already exists, means someone else thought they - # could use it so we need to + # could use it so we need yank it from them (and not use it) if grp in list(six.itervalues(self.group_abbrev)): self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp} else: diff --git a/disco/bot/command.py b/disco/bot/command.py index 1cf00a1..6eb7cc4 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -50,6 +50,9 @@ class CommandEvent(object): @property def codeblock(self): + if '`' not in self.msg.content: + return ' '.join(self.args) + _, src = self.msg.content.split('`', 1) src = '`' + src diff --git a/disco/types/message.py b/disco/types/message.py index e7b2873..500428f 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -368,7 +368,7 @@ class Message(SlottedModel): if user_replace: replace_user = functools.partial(replace, self.mentions.get, user_replace) - content = re.sub('(<@!?([0-9]+)>)', replace_user, self.content) + content = re.sub('(<@!?([0-9]+)>)', replace_user, content) if role_replace: replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace) From cc9ba3d641d055260751dab553ecf0616472de60 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sun, 13 Nov 2016 04:57:16 -0600 Subject: [PATCH 33/91] bugfix - user type should look in state if its not a mention resolver --- disco/bot/command.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index 6eb7cc4..3471af8 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -147,9 +147,12 @@ class Command(object): def resolve_user(ctx, uid): if isinstance(uid, int): - return ctx.msg.mentions.get(uid) + if uid in ctx.msg.mentions: + return ctx.msg.mentions.get(uid) + else: + return ctx.msg.client.state.users.get(uid) else: - return ctx.msg.mentions.select_one(username=uid[0], discriminator=uid[1]) + return ctx.msg.client.state.users.select_one(username=uid[0], discriminator=uid[1]) def resolve_channel(ctx, cid): return ctx.msg.guild.channels.get(cid) From cae3ceff85e3154db27005eeb35cdd5be965e127 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 15 Nov 2016 08:00:14 -0600 Subject: [PATCH 34/91] cleanup - various bits of cleanup --- disco/bot/bot.py | 1 - disco/bot/command.py | 17 ++++++++++++----- disco/bot/plugin.py | 11 ++--------- disco/gateway/client.py | 2 -- disco/types/base.py | 5 +++-- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index e4ba373..451000a 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -64,7 +64,6 @@ class BotConfig(Config): The directory plugin configuration is located within. """ levels = {} - plugins = [] commands_enabled = True commands_require_mention = True diff --git a/disco/bot/command.py b/disco/bot/command.py index 3471af8..a9118b8 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -46,7 +46,10 @@ class CommandEvent(object): self.msg = msg self.match = match self.name = self.match.group(1) - self.args = [i for i in self.match.group(2).strip().split(' ') if i] + self.args = [] + + if self.match.group(2): + self.args = [i for i in self.match.group(2).strip().split(' ') if i] @property def codeblock(self): @@ -133,13 +136,17 @@ class Command(object): self.is_regex = None self.oob = False self.context = {} + self.metadata = {} self.update(*args, **kwargs) + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + def get_docstring(self): return (self.func.__doc__ or '').format(**self.context) - def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None, dispatch_func=None): + def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None, **kwargs): self.triggers += aliases or [] def resolve_role(ctx, rid): @@ -168,7 +175,7 @@ class Command(object): self.is_regex = is_regex self.oob = oob self.context = context or {} - self.dispatch_func = dispatch_func + self.metadata = kwargs @staticmethod def mention_type(getters, reg, user=False): @@ -198,7 +205,7 @@ class Command(object): """ A compiled version of this command's regex. """ - return re.compile(self.regex) + return re.compile(self.regex, re.I) @property def regex(self): @@ -238,4 +245,4 @@ class Command(object): except ArgumentError as e: raise CommandError(e.message) - return (self.dispatch_func or self.func)(event, *args, **self.context) + return self.plugin.dispatch('command', self, event, *args, **self.context) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index fcbc9ad..a99f532 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -244,7 +244,7 @@ class Plugin(LoggingClass, PluginDeco): """ getattr(self, '_' + when)[typ].append(func) - def _dispatch(self, typ, func, event, *args, **kwargs): + def dispatch(self, typ, func, event, *args, **kwargs): # TODO: this is ugly if typ != 'command': self.greenlets.add(gevent.getcurrent()) @@ -283,7 +283,7 @@ class Plugin(LoggingClass, PluginDeco): desc The descriptor of the event/packet. """ - args = list(args) + [functools.partial(self._dispatch, 'listener', func)] + args = list(args) + [functools.partial(self.dispatch, 'listener', func)] if what == 'event': li = self.bot.client.events.on(*args, **kwargs) @@ -308,13 +308,6 @@ class Plugin(LoggingClass, PluginDeco): Keyword arguments to pass onto the :class:`disco.bot.command.Command` object. """ - # name = args[0] - - # if kwargs.pop('update', False) and name in self.commands: - # self.commands[name].update(*args, **kwargs) - # else: - wrapped = functools.partial(self._dispatch, 'command', func) - kwargs.setdefault('dispatch_func', wrapped) self.commands.append(Command(self, func, *args, **kwargs)) def register_schedule(self, func, interval, repeat=True, init=True): diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 5ff6956..a744890 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -147,8 +147,6 @@ class GatewayClient(LoggingClass): raise Exception('WS recieved error: %s', error) def on_open(self): - self.log.info('Opened, headers: %s', self.ws.sock.headers) - if self.seq and self.session_id: self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.send(OPCode.RESUME, { diff --git a/disco/types/base.py b/disco/types/base.py index dfa1e2c..4bee934 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -34,7 +34,8 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None, **kwargs): + def __init__(self, value_type, alias=None, default=None, create=True, **kwargs): + # TODO: fix default bullshit self.src_name = alias self.dst_name = None self.metadata = kwargs @@ -51,7 +52,7 @@ class Field(object): if isinstance(self.deserializer, Field) and self.default is None: self.default = self.deserializer.default - elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None: + elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None and create: self.default = self.deserializer @property From 492b26326ae6ee2b3932a59c2521cb1f8ec17f0d Mon Sep 17 00:00:00 2001 From: Andrei Date: Sun, 20 Nov 2016 21:37:47 -0600 Subject: [PATCH 35/91] More hashmaps, cleanup and fixes --- disco/api/client.py | 6 +++--- disco/api/http.py | 2 +- disco/bot/bot.py | 3 ++- disco/bot/command.py | 2 +- disco/bot/plugin.py | 4 ++-- disco/types/base.py | 7 ++++++- disco/types/message.py | 33 ++++++++++++++++++++------------- disco/types/user.py | 2 +- 8 files changed, 36 insertions(+), 23 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index c8fe702..a478b9b 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -193,7 +193,7 @@ class APIClient(LoggingClass): def guilds_channels_list(self, guild): r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) - return Channel.create_map(self.client, r.json(), guild_id=guild) + return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild) def guilds_channels_create(self, guild, **kwargs): r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs) @@ -207,7 +207,7 @@ class APIClient(LoggingClass): def guilds_members_list(self, guild): r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild)) - return GuildMember.create_map(self.client, r.json(), guild_id=guild) + return GuildMember.create_hash(self.client, 'id', r.json(), guild_id=guild) def guilds_members_get(self, guild, member): r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member)) @@ -224,7 +224,7 @@ class APIClient(LoggingClass): def guilds_bans_list(self, guild): r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild)) - return User.create_map(self.client, r.json()) + return User.create_hash(self.client, 'id', r.json()) def guilds_bans_create(self, guild, user, delete_message_days): self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={ diff --git a/disco/api/http.py b/disco/api/http.py index 2ebf941..10dcd22 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -211,7 +211,7 @@ class HTTPClient(LoggingClass): # Make the actual request url = self.BASE_URL + route[1].format(**args) - self.log.info('%s %s', route[0].value, url) + self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params')) r = requests.request(route[0].value, url, **kwargs) # Update rate limiter diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 451000a..fccda64 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -214,7 +214,8 @@ class Bot(object): """ Computes a single regex which matches all possible command combinations. """ - re_str = '|'.join(command.regex for command in self.commands) + commands = list(self.commands) + re_str = '|'.join(command.regex for command in commands) if re_str: self.command_matches_re = re.compile(re_str) else: diff --git a/disco/bot/command.py b/disco/bot/command.py index a9118b8..6623a86 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -5,7 +5,7 @@ from holster.enum import Enum from disco.bot.parser import ArgumentSet, ArgumentError from disco.util.functional import cached_property -REGEX_FMT = '({})' +REGEX_FMT = '{}' ARGS_REGEX = '( ((?:\n|.)*)$|$)' USER_MENTION_RE = re.compile('<@!?([0-9]+)>') diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index a99f532..e8a0b9d 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -258,7 +258,7 @@ class Plugin(LoggingClass, PluginDeco): self.ctx['user'] = event.author for pre in self._pre[typ]: - event = pre(event, args, kwargs) + event = pre(func, event, args, kwargs) if event is None: return False @@ -266,7 +266,7 @@ class Plugin(LoggingClass, PluginDeco): result = func(event, *args, **kwargs) for post in self._post[typ]: - post(event, args, kwargs, result) + post(func, event, args, kwargs, result) return True diff --git a/disco/types/base.py b/disco/types/base.py index 4bee934..0c81e70 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -74,7 +74,6 @@ class Field(object): try: return self.deserializer(raw, client) except Exception as e: - raise six.reraise(ConversionError, ConversionError(self, raw, e)) @staticmethod @@ -337,6 +336,12 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): def create_map(cls, client, data, **kwargs): return list(map(functools.partial(cls.create, client, **kwargs), data)) + @classmethod + def create_hash(cls, client, key, data, **kwargs): + return HashMap({ + getattr(item, key): cls.create(client, item, **kwargs) for item in data + }) + @classmethod def attach(cls, it, data): for item in it: diff --git a/disco/types/message.py b/disco/types/message.py index 500428f..53276c7 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,5 +1,7 @@ import re +import six import functools +import unicodedata from holster.enum import Enum @@ -105,7 +107,7 @@ class MessageEmbed(SlottedModel): title = Field(text) type = Field(str, default='rich') description = Field(text) - url = Field(str) + url = Field(text) timestamp = Field(lazy_datetime) color = Field(int) footer = Field(MessageEmbedFooter) @@ -139,8 +141,8 @@ class MessageAttachment(SlottedModel): """ id = Field(str) filename = Field(text) - url = Field(str) - proxy_url = Field(str) + url = Field(text) + proxy_url = Field(text) size = Field(int) height = Field(int) width = Field(int) @@ -327,13 +329,13 @@ class Message(SlottedModel): @cached_property def with_proper_mentions(self): def replace_user(u): - return '@' + str(u) + return u'@' + six.text_type(u) def replace_role(r): - return '@' + str(r) + return u'@' + six.text_type(r) def replace_channel(c): - return str(c) + return six.text_type(c) return self.replace_mentions(replace_user, replace_role, replace_channel) @@ -382,25 +384,28 @@ class Message(SlottedModel): class MessageTable(object): - def __init__(self, sep=' | ', codeblock=True, header_break=True): + def __init__(self, sep=' | ', codeblock=True, header_break=True, language=None): self.header = [] self.entries = [] self.size_index = {} self.sep = sep self.codeblock = codeblock self.header_break = header_break + self.language = language def recalculate_size_index(self, cols): for idx, col in enumerate(cols): - if idx not in self.size_index or len(col) > self.size_index[idx]: - self.size_index[idx] = len(col) + size = len(unicodedata.normalize('NFC', col)) + if idx not in self.size_index or size > self.size_index[idx]: + self.size_index[idx] = size def set_header(self, *args): + args = list(map(six.text_type, args)) self.header = args self.recalculate_size_index(args) def add(self, *args): - args = list(map(lambda v: v if isinstance(v, basestring) else str(v), args)) + args = list(map(six.text_type, args)) self.entries.append(args) self.recalculate_size_index(args) @@ -414,15 +419,17 @@ class MessageTable(object): return data.rstrip() def compile(self): - data = [self.compile_one(self.header)] + data = [] + if self.header: + data = [self.compile_one(self.header)] - if self.header_break: + if self.header and self.header_break: data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1)) for row in self.entries: data.append(self.compile_one(row)) if self.codeblock: - return '```' + '\n'.join(data) + '```' + return '```{}'.format(self.language if self.language else '') + '\n'.join(data) + '```' return '\n'.join(data) diff --git a/disco/types/user.py b/disco/types/user.py index c2bac79..51bbfbf 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -28,7 +28,7 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): return '<@{}>'.format(self.id) def __str__(self): - return u'{}#{}'.format(self.username, self.discriminator) + return u'{}#{}'.format(self.username, str(self.discriminator).zfill(4)) def __repr__(self): return u''.format(self.id, self) From d80e3c4c57cfa026d02bf5d20f1d386ff2eb9eb7 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 22 Nov 2016 17:25:32 -0600 Subject: [PATCH 36/91] Swap to using RESTful role add/remove endpoints --- disco/api/client.py | 6 ++++++ disco/api/http.py | 2 ++ disco/types/guild.py | 6 ++++-- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index a478b9b..1ee9026 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -216,6 +216,12 @@ class APIClient(LoggingClass): def guilds_members_modify(self, guild, member, **kwargs): self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) + def guilds_members_roles_add(self, guild, member, role): + self.http(Routes.GUILDS_MEMBERS_ROLES_ADD, dict(guild=guild, member=member, role=role)) + + def guilds_members_roles_remove(self, guild, member, role): + self.http(Routes.GUILDS_MEMBERS_ROLES_REMOVE, dict(guild=guild, member=member, role=role)) + def guilds_members_me_nick(self, guild, nick): self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick}) diff --git a/disco/api/http.py b/disco/api/http.py index 10dcd22..c1c54db 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -71,6 +71,8 @@ class Routes(object): GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members') GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}') GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}') + GUILDS_MEMBERS_ROLES_ADD = (HTTPMethod.PUT, GUILDS + '/members/{member}/roles/{role}') + GUILDS_MEMBERS_ROLES_REMOVE = (HTTPMethod.DELETE, GUILDS + '/members/{member}/roles/{role}') GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick') GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}') GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans') diff --git a/disco/types/guild.py b/disco/types/guild.py index 220ffa4..2206547 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -177,8 +177,10 @@ class GuildMember(SlottedModel): self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') def add_role(self, role): - roles = self.roles + [role.id] - self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles) + self.client.api.guilds_members_roles_add(self.guild.id, self.user.id, to_snowflake(role)) + + def remove_role(self, role): + self.clients.api.guilds_members_roles_remove(self.guild.id, self.user.id, to_snowflake(role)) @cached_property def owner(self): From 708453135304eb4dd7516b63fd788843c436dedf Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 22 Nov 2016 19:14:52 -0600 Subject: [PATCH 37/91] bugfix - parsed command arguments should be passed in kwargs This fixes some weird edge cases with positional arguments when using multiple command definitions for a single function. --- disco/bot/command.py | 7 +++++-- disco/bot/parser.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index 6623a86..c51069e 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -241,8 +241,11 @@ class Command(object): )) try: - args = self.args.parse(event.args, ctx=event) + parsed_args = self.args.parse(event.args, ctx=event) except ArgumentError as e: raise CommandError(e.message) - return self.plugin.dispatch('command', self, event, *args, **self.context) + kwargs = {} + kwargs.update(self.context) + kwargs.update(parsed_args) + return self.plugin.dispatch('command', self, event, **kwargs) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index fab4513..af71337 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -145,7 +145,7 @@ class ArgumentSet(object): """ Parse a string of raw arguments into this argument specification. """ - parsed = [] + parsed = {} for index, arg in enumerate(self.args): if not arg.required and index + arg.true_count > len(rawargs): @@ -171,7 +171,7 @@ class ArgumentSet(object): if (not arg.types or arg.types == ['str']) and isinstance(raw, list): raw = ' '.join(raw) - parsed.append(raw) + parsed[arg.name] = raw return parsed From b7535790fba64b8d4bf0aec1374323de46ce5ec4 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 22 Nov 2016 19:28:28 -0600 Subject: [PATCH 38/91] bugfix - set a proper User-Agent header --- disco/api/http.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/disco/api/http.py b/disco/api/http.py index c1c54db..c583bd7 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -2,9 +2,12 @@ import requests import random import gevent import six +import sys from holster.enum import Enum +from disco import VERSION as disco_version +from requests import __version__ as requests_version from disco.util.logging import LoggingClass from disco.api.ratelimit import RateLimiter @@ -156,9 +159,18 @@ class HTTPClient(LoggingClass): def __init__(self, token): super(HTTPClient, self).__init__() + py_version = '{}.{}.{}'.format( + sys.version_info.major, + sys.version_info.minor, + sys.version_info.micro) + self.limiter = RateLimiter() self.headers = { 'Authorization': 'Bot ' + token, + 'User-Agent': 'DiscordBot (https://github.com/b1naryth1ef/disco {}) Python/{} requests/{}'.format( + disco_version, + py_version, + requests_version), } def __call__(self, route, args=None, **kwargs): From d9015dd3c07fe07d1370027ded2ae03a19e49edf Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 22 Nov 2016 20:17:32 -0600 Subject: [PATCH 39/91] feature - add support for flags/bools in command arguments --- disco/bot/parser.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index af71337..d5d4611 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -4,7 +4,9 @@ import copy # Regex which splits out argument parts -PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])') +PARTS_RE = re.compile('(\<|\[|\{)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\]|\})') + +BOOL_OPTS = {'yes': True, 'no': False, 'true': True, 'False': False, '1': True, '0': False} # Mapping of types TYPE_MAP = { @@ -15,6 +17,14 @@ TYPE_MAP = { } +def to_bool(ctx, data): + if data in BOOL_OPTS: + return BOOL_OPTS[data] + raise TypeError + +TYPE_MAP['bool'] = to_bool + + class ArgumentError(Exception): """ An error thrown when passed in arguments cannot be conformed/casted to the @@ -41,6 +51,7 @@ class Argument(object): self.name = None self.count = 1 self.required = False + self.flag = False self.types = None self.parse(raw) @@ -62,12 +73,16 @@ class Argument(object): else: self.required = False - if part.endswith('...'): - part = part[:-3] - self.count = 0 - elif ' ' in part: - split = part.split(' ', 1) - part, self.count = split[0], int(split[1]) + # Whether this is a flag + self.flag = (prefix == '{') + + if not self.flag: + if part.endswith('...'): + part = part[:-3] + self.count = 0 + elif ' ' in part: + split = part.split(' ', 1) + part, self.count = split[0], int(split[1]) if ':' in part: part, typeinfo = part.split(':') @@ -156,7 +171,15 @@ class ArgumentSet(object): else: raw = rawargs[index:index + arg.true_count] - if arg.types: + if arg.flag: + raw = raw[0].lstrip('-') + if raw == arg.name: + raw = [True] + elif '=' in raw: + raw = [self.convert(ctx, arg.types, raw.split('=', 1)[-1])] + else: + continue + elif arg.types: for idx, r in enumerate(raw): try: raw[idx] = self.convert(ctx, arg.types, r) From 6f8684bb69a65768d9224db56d85604c158caa6a Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 24 Nov 2016 01:38:07 -0600 Subject: [PATCH 40/91] refactor - better parser flags implementation --- disco/bot/parser.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index d5d4611..1df5e9f 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -79,14 +79,16 @@ class Argument(object): if not self.flag: if part.endswith('...'): part = part[:-3] - self.count = 0 + + if self.flag: + raise TypeError('Cannot use nargs on flag') elif ' ' in part: split = part.split(' ', 1) part, self.count = split[0], int(split[1]) - if ':' in part: - part, typeinfo = part.split(':') - self.types = typeinfo.split('|') + if ':' in part: + part, typeinfo = part.split(':') + self.types = typeinfo.split('|') self.name = part.strip() @@ -162,7 +164,22 @@ class ArgumentSet(object): """ parsed = {} - for index, arg in enumerate(self.args): + flags = {i.name: i for i in self.args if i.flag} + if not flags: + return parsed + + new_rawargs = [] + + for offset, raw in enumerate(rawargs): + if raw.startswith('-'): + raw = raw.lstrip('-') + if raw in flags: + parsed[raw] = True + continue + new_rawargs.append(raw) + + rawargs = new_rawargs + for index, arg in enumerate((arg for arg in self.args if not arg.flag)): if not arg.required and index + arg.true_count > len(rawargs): continue @@ -171,15 +188,7 @@ class ArgumentSet(object): else: raw = rawargs[index:index + arg.true_count] - if arg.flag: - raw = raw[0].lstrip('-') - if raw == arg.name: - raw = [True] - elif '=' in raw: - raw = [self.convert(ctx, arg.types, raw.split('=', 1)[-1])] - else: - continue - elif arg.types: + if arg.types: for idx, r in enumerate(raw): try: raw[idx] = self.convert(ctx, arg.types, r) From 5d2d20a42bb89050e66d29a820b4aab99e0c1afb Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 24 Nov 2016 07:55:55 -0600 Subject: [PATCH 41/91] feature - add guild emoji management routes --- disco/api/client.py | 17 ++++++++++++++++- disco/api/http.py | 4 ++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/disco/api/client.py b/disco/api/client.py index 1ee9026..eb8cec5 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -5,7 +5,7 @@ from disco.util.logging import LoggingClass from disco.types.user import User from disco.types.message import Message -from disco.types.guild import Guild, GuildMember, Role +from disco.types.guild import Guild, GuildMember, Role, GuildEmoji from disco.types.channel import Channel from disco.types.invite import Invite from disco.types.webhook import Webhook @@ -263,6 +263,21 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_WEBHOOKS_LIST, dict(guild=guild)) return Webhook.create_map(self.client, r.json()) + def guilds_emojis_list(self, guild): + r = self.http(Routes.GUILDS_EMOJIS_LIST, dict(guild=guild)) + return GuildEmoji.create_map(self.client, r.json()) + + def guilds_emojis_create(self, guild, **kwargs): + r = self.http(Routes.GUILDS_EMOJIS_CREATE, dict(guild=guild), json=kwargs) + return GuildEmoji.create(self.client, r.json()) + + def guilds_emojis_modify(self, guild, emoji, **kwargs): + r = self.http(Routes.GUILDS_EMOJIS_MODIFY, dict(guild=guild, emoji=emoji), json=kwargs) + return GuildEmoji.create(self.client, r.json()) + + def guilds_emojis_delete(self, guild, emoji): + self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji)) + def invites_get(self, invite): r = self.http(Routes.INVITES_GET, dict(invite=invite)) return Invite.create(self.client, r.json()) diff --git a/disco/api/http.py b/disco/api/http.py index c583bd7..44801ba 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -98,6 +98,10 @@ class Routes(object): GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed') GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed') GUILDS_WEBHOOKS_LIST = (HTTPMethod.GET, GUILDS + '/webhooks') + GUILDS_EMOJIS_LIST = (HTTPMethod.GET, GUILDS + '/emojis') + GUILDS_EMOJIS_CREATE = (HTTPMethod.POST, GUILDS + '/emojis') + GUILDS_EMOJIS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/emojis/{emoji}') + GUILDS_EMOJIS_DELETE = (HTTPMethod.DELETE, GUILDS + '/emojis/{emoji}') # Users USERS = '/users' From 0b2305a7460011f99bc2c886248a4a124ac3d285 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 24 Nov 2016 07:56:25 -0600 Subject: [PATCH 42/91] bugfix - fix flag parsing be super derp --- disco/bot/parser.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 1df5e9f..d6f38db 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -165,20 +165,19 @@ class ArgumentSet(object): parsed = {} flags = {i.name: i for i in self.args if i.flag} - if not flags: - return parsed + if flags: + new_rawargs = [] - new_rawargs = [] + for offset, raw in enumerate(rawargs): + if raw.startswith('-'): + raw = raw.lstrip('-') + if raw in flags: + parsed[raw] = True + continue + new_rawargs.append(raw) - for offset, raw in enumerate(rawargs): - if raw.startswith('-'): - raw = raw.lstrip('-') - if raw in flags: - parsed[raw] = True - continue - new_rawargs.append(raw) + rawargs = new_rawargs - rawargs = new_rawargs for index, arg in enumerate((arg for arg in self.args if not arg.flag)): if not arg.required and index + arg.true_count > len(rawargs): continue From ca36412c2e5f2709f9c8b98e851aa3f9bf589f03 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 24 Nov 2016 07:56:43 -0600 Subject: [PATCH 43/91] feature - track GuildEmojisUpdate in state --- disco/state.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/disco/state.py b/disco/state.py index 81efbb3..86bbcb9 100644 --- a/disco/state.py +++ b/disco/state.py @@ -88,7 +88,7 @@ class State(object): EVENTS = [ 'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove', 'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete', - 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate', + 'GuildEmojisUpdate', 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate', 'PresenceUpdate' ] @@ -296,6 +296,12 @@ class State(object): del self.guilds[event.guild_id].roles[event.role_id] + def on_guild_emojis_update(self, event): + if event.guild_id not in self.guilds: + return + + self.guilds[event.guild_id].emojis = HashMap({i.id: i for i in event.emojis}) + def on_presence_update(self, event): if event.user.id in self.users: self.users[event.user.id].update(event.presence.user) From b41bcebef94f565641cbe9030666d91cc964c7ab Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 24 Nov 2016 07:56:54 -0600 Subject: [PATCH 44/91] feature - add url property to GuildEmoji --- disco/types/guild.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/disco/types/guild.py b/disco/types/guild.py index 2206547..5048535 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -49,6 +49,10 @@ class GuildEmoji(Emoji): managed = Field(bool) roles = ListField(snowflake) + @property + def url(self): + return 'https://discordapp.com/api/emojis/{}.png'.format(self.id) + @cached_property def guild(self): return self.client.state.guilds.get(self.guild_id) From fca15995798bb377c7e6df9f2e1d93a818b189d8 Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 28 Nov 2016 05:50:53 -0600 Subject: [PATCH 45/91] bugfix - fields without a default but set to None would become UNSET etc stuff as well --- disco/gateway/events.py | 4 ++-- disco/types/base.py | 17 ++++++++++------- disco/types/guild.py | 3 +++ 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 5fc28cf..1b4b0b6 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -7,7 +7,7 @@ from disco.types.user import User, Presence from disco.types.channel import Channel from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState -from disco.types.guild import Guild, GuildMember, Role, Emoji +from disco.types.guild import Guild, GuildMember, Role, GuildEmoji from disco.types.base import Model, ModelMeta, Field, ListField, snowflake, lazy_datetime @@ -295,7 +295,7 @@ class GuildEmojisUpdate(GatewayEvent): The new set of emojis for the guild """ guild_id = Field(snowflake) - emojis = ListField(Emoji) + emojis = ListField(GuildEmoji) class GuildIntegrationsUpdate(GatewayEvent): diff --git a/disco/types/base.py b/disco/types/base.py index 0c81e70..28e6da6 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -290,18 +290,21 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): should_skip = skip and name in skip if consume and not should_skip: - raw = obj.pop(field.src_name, None) + raw = obj.pop(field.src_name, UNSET) else: - raw = obj.get(field.src_name, None) + raw = obj.get(field.src_name, UNSET) - if raw is None or should_skip: - if field.has_default(): - default = field.default() if callable(field.default) else field.default - else: - default = UNSET + # If the field is unset/none, and we have a default we need to set it + if (raw in (None, UNSET) or should_skip) and field.has_default(): + default = field.default() if callable(field.default) else field.default setattr(self, field.dst_name, default) continue + # Otherwise if the field is UNSET and has no default, skip conversion + if raw is UNSET or should_skip: + setattr(self, field.dst_name, raw) + continue + value = field.try_convert(raw, self.client) setattr(self, field.dst_name, value) diff --git a/disco/types/guild.py b/disco/types/guild.py index 5048535..0b652fa 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -49,6 +49,9 @@ class GuildEmoji(Emoji): managed = Field(bool) roles = ListField(snowflake) + def __str__(self): + return u'<:{}:{}>'.format(self.name, self.id) + @property def url(self): return 'https://discordapp.com/api/emojis/{}.png'.format(self.id) From 93cf4ac660d853cecbc2e0f17295a281a7fa6a12 Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 28 Nov 2016 09:37:04 -0600 Subject: [PATCH 46/91] bugfix - private_channels is a list of channels, not guilds (derp) --- disco/gateway/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 1b4b0b6..48e62ef 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -126,7 +126,7 @@ class Ready(GatewayEvent): session_id = Field(str) user = Field(User) guilds = ListField(Guild) - private_channels = ListField(Guild) + private_channels = ListField(Channel) class Resumed(GatewayEvent): From 8d3e290e032268bf5fe204c99f32ed941b8b65ab Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 28 Nov 2016 15:33:45 -0600 Subject: [PATCH 47/91] feature - add GuildMember.unban utility method --- disco/types/guild.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/disco/types/guild.py b/disco/types/guild.py index 0b652fa..468acbc 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -169,6 +169,12 @@ class GuildMember(SlottedModel): """ self.guild.create_ban(self, delete_message_days) + def unban(self): + """ + Unbans the member from the guild. + """ + self.guild.delete_ban(self) + def set_nickname(self, nickname=None): """ Sets the member's nickname (or clears it if None). From b85376be3262893aa2e527dbc2e73185ea97dbab Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 1 Dec 2016 14:26:06 -0600 Subject: [PATCH 48/91] Add MessageReactionRemoveAll, etc fixes - Fix passing plugins on command line - Fix shard count in AutoSharder --- disco/cli.py | 5 ++++- disco/gateway/events.py | 15 +++++++++++++++ disco/gateway/sharder.py | 2 -- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/disco/cli.py b/disco/cli.py index 4b9d95b..10aa9f3 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -69,7 +69,10 @@ def disco_main(run=False): bot = None if args.run_bot or hasattr(config, 'bot'): bot_config = BotConfig(config.bot) if hasattr(config, 'bot') else BotConfig() - bot_config.plugins += args.plugin + if not hasattr(bot_config, 'plugins'): + bot_config.plugins = args.plugin + else: + bot_config.plugins += args.plugin bot = Bot(client, bot_config) if run: diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 48e62ef..958b7a9 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -633,3 +633,18 @@ class MessageReactionRemove(GatewayEvent): @property def guild(self): return self.channel.guild + + +class MessageReactionRemoveAll(GatewayEvent): + """ + Sent when all reactions are removed from a message. + + Attributes + ---------- + channel_id : snowflake + The channel ID the message is in. + message_id : snowflake + The ID of the message for which the reactions where removed from. + """ + channel_id = Field(snowflake) + message_id = Field(snowflake) diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py index 0321cee..99e8d1a 100644 --- a/disco/gateway/sharder.py +++ b/disco/gateway/sharder.py @@ -63,8 +63,6 @@ class AutoSharder(object): self.client = APIClient(config.token) self.shards = {} self.config.shard_count = self.client.gateway_bot_get()['shards'] - if self.config.shard_count > 1: - self.config.shard_count = 10 def run_on(self, sid, raw): func = load_function(raw) From bef1492d543bd1eb4862b2b6098d53f4b0611c3b Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 1 Dec 2016 16:54:43 -0600 Subject: [PATCH 49/91] Bump holster, fix SimpleLimiter race error, fix to_dict recursion --- disco/types/base.py | 28 +++++++++++++++++++--------- disco/types/user.py | 2 +- disco/util/limiter.py | 3 ++- requirements.txt | 2 +- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/disco/types/base.py b/disco/types/base.py index 28e6da6..f72d8bb 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -34,10 +34,11 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None, create=True, **kwargs): + def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, **kwargs): # TODO: fix default bullshit self.src_name = alias self.dst_name = None + self.ignore_dump = ignore_dump or [] self.metadata = kwargs if default is not None: @@ -88,11 +89,11 @@ class Field(object): return lambda raw, _: typ(raw) @staticmethod - def serialize(value): + def serialize(value, inst=None): if isinstance(value, EnumAttr): return value.value elif isinstance(value, Model): - return value.to_dict() + return value.to_dict(ignore=(inst.ignore_dump if inst else [])) else: return value @@ -109,8 +110,11 @@ class DictField(Field): self.value_de = self.type_to_deserializer(value_type or key_type) @staticmethod - def serialize(value): - return {Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value)} + def serialize(value, inst=None): + return { + Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value) + if k not in (inst.ignore_dump if inst else []) + } def try_convert(self, raw, client): return HashMap({ @@ -122,7 +126,7 @@ class ListField(Field): default = list @staticmethod - def serialize(value): + def serialize(value, inst=None): return list(map(Field.serialize, value)) def try_convert(self, raw, client): @@ -211,7 +215,10 @@ def binary(obj): def with_equality(field): class T(object): def __eq__(self, other): - return getattr(self, field) == getattr(other, field) + if isinstance(other, self.__class__): + return getattr(self, field) == getattr(other, field) + else: + return getattr(self, field) == other return T @@ -321,12 +328,15 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): except: pass - def to_dict(self): + def to_dict(self, ignore=None): obj = {} for name, field in six.iteritems(self.__class__._fields): + if ignore and name in ignore: + continue + if getattr(self, name) == UNSET: continue - obj[name] = field.serialize(getattr(self, name)) + obj[name] = field.serialize(getattr(self, name), field) return obj @classmethod diff --git a/disco/types/user.py b/disco/types/user.py index 51bbfbf..b04c180 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -55,6 +55,6 @@ class Game(SlottedModel): class Presence(SlottedModel): - user = Field(User, alias='user') + user = Field(User, alias='user', ignore_dump=['presence']) game = Field(Game) status = Field(Status) diff --git a/disco/util/limiter.py b/disco/util/limiter.py index 6992832..ccb7622 100644 --- a/disco/util/limiter.py +++ b/disco/util/limiter.py @@ -17,7 +17,8 @@ class SimpleLimiter(object): gevent.sleep(self.reset_at - time.time()) self.count = 0 self.reset_at = 0 - self.event.set() + if self.event: + self.event.set() self.event = None def check(self): diff --git a/requirements.txt b/requirements.txt index a41d421..3b0894b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ gevent==1.1.2 -holster==1.0.9 +holster==1.0.11 inflection==0.3.1 requests==2.11.1 six==1.10.0 From 2e7b707e1e0f9a1da3955ea7585eaf0d7e8d45f4 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 2 Dec 2016 15:56:55 -0600 Subject: [PATCH 50/91] bugfix - n-size arguments would be limited to one argument --- disco/bot/parser.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index d6f38db..16ee3f2 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -79,9 +79,7 @@ class Argument(object): if not self.flag: if part.endswith('...'): part = part[:-3] - - if self.flag: - raise TypeError('Cannot use nargs on flag') + self.count = 0 elif ' ' in part: split = part.split(' ', 1) part, self.count = split[0], int(split[1]) From 704f784bef044314c2f677d7723e56100e45ffb4 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 3 Dec 2016 20:58:03 -0600 Subject: [PATCH 51/91] refactor - make APIException less corner-casey, easier to interact with --- disco/api/http.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/disco/api/http.py b/disco/api/http.py index 44801ba..dfa356d 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -141,13 +141,27 @@ class APIException(Exception): The status code returned by the API for the request that triggered this error. """ - def __init__(self, msg, status_code=0, content=None): - self.status_code = status_code - self.content = content - self.msg = msg - - if self.status_code: - self.msg += ' code: {}'.format(status_code) + def __init__(self, response, retries=None): + self.response = response + self.code = 0 + self.msg = 'Request Failed ({})'.format(response.status_code) + + # Try to decode JSON, and extract params + try: + data = self.response.json() + + if 'code' in data: + self.code = data['code'] + self.msg = data['message'] + elif len(data) == 1: + key, value = list(data.items())[0] + self.msg = 'Request Failed: {}: {}'.format(key, ', '.join(value)) + except ValueError: + pass + + # DEPRECATED: left for backwards compat + self.status_code = response.status_code + self.content = response.content super(APIException, self).__init__(self.msg) @@ -239,7 +253,7 @@ class HTTPClient(LoggingClass): if r.status_code < 400: return r elif r.status_code != 429 and 400 <= r.status_code < 500: - raise APIException('Request failed', r.status_code, r.content) + raise APIException(r) else: if r.status_code == 429: self.log.warning( @@ -249,8 +263,7 @@ class HTTPClient(LoggingClass): retry += 1 if retry > self.MAX_RETRIES: self.log.error('Failing request, hit max retries') - raise APIException( - 'Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) + raise APIException(r, retries=self.MAX_RETRIES) backoff = self.random_backoff() self.log.warning('Request to `{}` failed with code {}, retrying after {}s ({})'.format( From 48c8fa6fa9c5ef00f94d2422bfa45b64cc0871f7 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 3 Dec 2016 20:58:21 -0600 Subject: [PATCH 52/91] feature -Add some logging to Bot, allow for handling of command/event exceptions --- disco/bot/bot.py | 6 ++++-- disco/bot/plugin.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index fccda64..d1a31f6 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -12,6 +12,7 @@ from disco.bot.plugin import Plugin from disco.bot.command import CommandEvent, CommandLevels from disco.bot.storage import Storage from disco.util.config import Config +from disco.util.logging import LoggingClass from disco.util.serializer import Serializer @@ -87,7 +88,7 @@ class BotConfig(Config): storage_config = {} -class Bot(object): +class Bot(LoggingClass): """ Disco's implementation of a simple but extendable Discord bot. Bots consist of a set of plugins, and a Disco client. @@ -380,6 +381,7 @@ class Bot(object): unload. """ if cls.__name__ in self.plugins: + self.log.warning('Attempted to add already added plugin %s', cls.__name__) raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) if not config: @@ -431,7 +433,7 @@ class Bot(object): """ Adds and loads a plugin, based on its module path. """ - + self.log.info('Adding plugin module at path "%s"', path) mod = importlib.import_module(path) loaded = False diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index e8a0b9d..a4ef0ea 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -205,6 +205,9 @@ class Plugin(LoggingClass, PluginDeco): when, typ = meta['type'].split('_', 1) self.register_trigger(typ, when, member) + def handle_exception(self, greenlet, event): + pass + def spawn_wrap(self, spawner, method, *args, **kwargs): def wrapped(*args, **kwargs): self.ctx['plugin'] = self @@ -245,9 +248,13 @@ class Plugin(LoggingClass, PluginDeco): getattr(self, '_' + when)[typ].append(func) def dispatch(self, typ, func, event, *args, **kwargs): + # Link the greenlet with our exception handler + gevent.getcurrent().link_exception(lambda g: self.handle_exception(g, event)) + # TODO: this is ugly if typ != 'command': self.greenlets.add(gevent.getcurrent()) + self.ctx['plugin'] = self if hasattr(event, 'guild'): From e0ea431d9d71765ff71cbba18a4181fe9917ce95 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 3 Dec 2016 20:59:16 -0600 Subject: [PATCH 53/91] cleanup - actually use retries var in APIException --- disco/api/http.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/disco/api/http.py b/disco/api/http.py index dfa356d..088b69e 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -143,9 +143,14 @@ class APIException(Exception): """ def __init__(self, response, retries=None): self.response = response + self.retries = retries + self.code = 0 self.msg = 'Request Failed ({})'.format(response.status_code) + if self.retries: + self.msg += " after {} retries".format(self.retries) + # Try to decode JSON, and extract params try: data = self.response.json() From 57b7e4b5497e2b3b23e7d461bef3115201c5f3b1 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 3 Dec 2016 21:19:38 -0600 Subject: [PATCH 54/91] feature - add guild "mention" type to parser --- disco/bot/command.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index c51069e..1ed5d9e 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -164,10 +164,14 @@ class Command(object): def resolve_channel(ctx, cid): return ctx.msg.guild.channels.get(cid) + def resolve_guild(ctx, gid): + return ctx.msg.client.state.guilds.get(gid) + self.args = ArgumentSet.from_string(args or '', { 'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True), 'role': self.mention_type([resolve_role], ROLE_MENTION_RE), 'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE), + 'guild': self.mention_type([resolve_guild]), }) self.level = level @@ -178,19 +182,21 @@ class Command(object): self.metadata = kwargs @staticmethod - def mention_type(getters, reg, user=False): + def mention_type(getters, reg=None, user=False): def _f(ctx, raw): if raw.isdigit(): resolved = int(raw) elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit(): username, discrim = raw.split('#') resolved = (username, int(discrim)) - else: + elif reg: res = reg.match(raw) if not res: raise TypeError('Invalid mention: {}'.format(raw)) resolved = int(res.group(1)) + else: + raise TypeError('Invalid mention: {}'.format(raw)) for getter in getters: obj = getter(ctx, resolved) From ec85c25342101d91345e6134c286f85dc02060ad Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 15 Dec 2016 12:38:59 -0600 Subject: [PATCH 55/91] bugfix - reset reconnection attempts after a successful resume --- disco/gateway/client.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/disco/gateway/client.py b/disco/gateway/client.py index a744890..9544af5 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -42,6 +42,7 @@ class GatewayClient(LoggingClass): # Bind to ready payload self.events.on('Ready', self.on_ready) + self.events.on('Resumed', self.on_resumed) # Websocket connection self.ws = None @@ -103,6 +104,10 @@ class GatewayClient(LoggingClass): self.session_id = ready.session_id self.reconnects = 0 + def on_resumed(self, _): + self.log.info('Recieved RESUMED') + self.reconnects = 0 + def connect_and_run(self, gateway_url=None): if not gateway_url: if not self._cached_gateway_url: @@ -188,7 +193,7 @@ class GatewayClient(LoggingClass): self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS: - raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) + raise Exception('Failed to reconnect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) # Don't resume for these error codes if code and 4000 <= code <= 4010: From a498f08217899651161b43218bf4533386f7f16d Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 22 Dec 2016 17:37:01 -0600 Subject: [PATCH 56/91] feature - add ability to control gateway max reconnects vi config --- disco/client.py | 5 ++++- disco/gateway/client.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/disco/client.py b/disco/client.py index 0de6f45..a5791ee 100644 --- a/disco/client.py +++ b/disco/client.py @@ -26,6 +26,8 @@ class ClientConfig(Config): The shard ID for the current client instance. shard_count : int The total count of shards running. + max_reconnects : int + The maximum number of connection retries to make before giving up (0 = never give up). manhole_enable : bool Whether to enable the manhole (e.g. console backdoor server) utility. manhole_bind : tuple(str, int) @@ -39,6 +41,7 @@ class ClientConfig(Config): token = "" shard_id = 0 shard_count = 1 + max_reconnects = 5 manhole_enable = False manhole_bind = ('127.0.0.1', 8484) @@ -86,7 +89,7 @@ class Client(LoggingClass): self.packets = Emitter(gevent.spawn) self.api = APIClient(self.config.token, self) - self.gw = GatewayClient(self, self.config.encoder) + self.gw = GatewayClient(self, self.config.max_reconnects, self.config.encoder) self.state = State(self, StateConfig(self.config.get('state', {}))) if self.config.manhole_enable: diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 9544af5..7386bb5 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -15,11 +15,11 @@ TEN_MEGABYTES = 10490000 class GatewayClient(LoggingClass): GATEWAY_VERSION = 6 - MAX_RECONNECTS = 5 - def __init__(self, client, encoder='json', ipc=None): + def __init__(self, client, max_reconnects=5, encoder='json', ipc=None): super(GatewayClient, self).__init__() self.client = client + self.max_reconnects = max_reconnects self.encoder = ENCODERS[encoder] self.events = client.events @@ -192,8 +192,8 @@ class GatewayClient(LoggingClass): self.reconnects += 1 self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) - if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS: - raise Exception('Failed to reconnect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) + if self.max_reconnects and self.reconnects > self.max_reconnects: + raise Exception('Failed to reconnect after {} attempts, giving up'.format(self.max_reconnects)) # Don't resume for these error codes if code and 4000 <= code <= 4010: From 21d54969e709ab0b6400237da72590ee6e9b869c Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 22 Dec 2016 17:37:34 -0600 Subject: [PATCH 57/91] bugfix - properly send embed when using multipart-form for attachments also fix sending an embed on edit --- disco/api/client.py | 7 ++++--- disco/types/message.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index eb8cec5..aa3aff4 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -1,4 +1,5 @@ import six +import json from disco.api.http import Routes, HTTPClient from disco.util.logging import LoggingClass @@ -98,7 +99,7 @@ class APIClient(LoggingClass): payload['embed'] = embed.to_dict() if attachment: - r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data=payload, files={ + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data={'payload_json': json.dumps(payload)}, files={ 'file': (attachment[0], attachment[1]) }) else: @@ -106,10 +107,10 @@ class APIClient(LoggingClass): return Message.create(self.client, r.json()) - def channels_messages_modify(self, channel, message, content): + def channels_messages_modify(self, channel, message, content, embed=None): r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, dict(channel=channel, message=message), - json={'content': content}) + json={'content': content, 'embed': embed.to_dict()}) return Message.create(self.client, r.json()) def channels_messages_delete(self, channel, message): diff --git a/disco/types/message.py b/disco/types/message.py index 53276c7..e9cd82f 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -254,7 +254,7 @@ class Message(SlottedModel): """ return self.channel.send_message(*args, **kwargs) - def edit(self, content): + def edit(self, *args, **kwargs): """ Edit this message. @@ -268,7 +268,7 @@ class Message(SlottedModel): :class:`Message` The edited message object. """ - return self.client.api.channels_messages_modify(self.channel_id, self.id, content) + return self.client.api.channels_messages_modify(self.channel_id, self.id, *args, **kwargs) def delete(self): """ From 013fd544448f1b28c8e7e62ded069f4d3f84ed29 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 27 Dec 2016 01:58:59 -0600 Subject: [PATCH 58/91] cleanup - Simplfy ARGS_REGEX even more This still needs to not be a shitty regex --- disco/bot/command.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index 1ed5d9e..b2283da 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -6,7 +6,7 @@ from disco.bot.parser import ArgumentSet, ArgumentError from disco.util.functional import cached_property REGEX_FMT = '{}' -ARGS_REGEX = '( ((?:\n|.)*)$|$)' +ARGS_REGEX = '(?: ((?:\n|.)*)$|$)' USER_MENTION_RE = re.compile('<@!?([0-9]+)>') ROLE_MENTION_RE = re.compile('<@&([0-9]+)>') @@ -45,11 +45,11 @@ class CommandEvent(object): self.command = command self.msg = msg self.match = match - self.name = self.match.group(1) + self.name = self.match.group(0) self.args = [] - if self.match.group(2): - self.args = [i for i in self.match.group(2).strip().split(' ') if i] + if self.match.group(1): + self.args = [i for i in self.match.group(1).strip().split(' ') if i] @property def codeblock(self): From a3924440b983c272726984094a4126b255e56602 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 4 Jan 2017 06:26:14 -0600 Subject: [PATCH 59/91] etc - better default log format --- disco/util/logging.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/disco/util/logging.py b/disco/util/logging.py index 5ce9498..68af8a8 100644 --- a/disco/util/logging.py +++ b/disco/util/logging.py @@ -7,8 +7,11 @@ LEVEL_OVERRIDES = { 'requests': logging.WARNING } +LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s' def setup_logging(**kwargs): + kwargs.setdefault('format', LOG_FORMAT) + logging.basicConfig(**kwargs) for logger, level in LEVEL_OVERRIDES.items(): logging.getLogger(logger).setLevel(level) From ceb57e19d43883c5b120677efd6be29fb72c6dfb Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 4 Jan 2017 09:39:59 -0600 Subject: [PATCH 60/91] bugfix - fix `create_hash` not reading from the created object also add GuildBan --- disco/api/client.py | 4 ++-- disco/types/base.py | 10 +++++++++- disco/types/guild.py | 5 +++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index aa3aff4..68c551d 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -6,7 +6,7 @@ from disco.util.logging import LoggingClass from disco.types.user import User from disco.types.message import Message -from disco.types.guild import Guild, GuildMember, Role, GuildEmoji +from disco.types.guild import Guild, GuildMember, GuildBan, Role, GuildEmoji from disco.types.channel import Channel from disco.types.invite import Invite from disco.types.webhook import Webhook @@ -231,7 +231,7 @@ class APIClient(LoggingClass): def guilds_bans_list(self, guild): r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild)) - return User.create_hash(self.client, 'id', r.json()) + return GuildBan.create_hash(self.client, 'user.id', r.json()) def guilds_bans_create(self, guild, user, delete_message_days): self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={ diff --git a/disco/types/base.py b/disco/types/base.py index f72d8bb..0cf0646 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -15,6 +15,12 @@ DATETIME_FORMATS = [ ] +def get_item_by_path(obj, path): + for part in path.split('.'): + obj = getattr(obj, part) + return obj + + class Unset(object): def __nonzero__(self): return False @@ -352,7 +358,9 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): @classmethod def create_hash(cls, client, key, data, **kwargs): return HashMap({ - getattr(item, key): cls.create(client, item, **kwargs) for item in data + get_item_by_path(item, key): item + for item in [ + cls.create(client, item, **kwargs) for item in data] }) @classmethod diff --git a/disco/types/guild.py b/disco/types/guild.py index 468acbc..a9075bb 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -110,6 +110,11 @@ class Role(SlottedModel): return self.client.state.guilds.get(self.guild_id) +class GuildBan(SlottedModel): + user = Field(User) + reason = Field(str) + + class GuildMember(SlottedModel): """ A GuildMember object. From 29768b84d24ca238684aae7b6ba37026cf4ca6e2 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 4 Jan 2017 10:26:41 -0600 Subject: [PATCH 61/91] feature - add trace field to Ready and Resumed events --- disco/gateway/events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 958b7a9..6430d84 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -127,13 +127,14 @@ class Ready(GatewayEvent): user = Field(User) guilds = ListField(Guild) private_channels = ListField(Channel) + trace = ListField(str, alias='_trace') class Resumed(GatewayEvent): """ Sent after a resume completes. """ - pass + trace = ListField(str, alias='_trace') @wraps_model(Guild) From 4f4c3ab2ff91612f9d5b815c43de6844eb5f38fe Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 4 Jan 2017 10:26:55 -0600 Subject: [PATCH 62/91] feature - Channel.recipients should be a hash map --- disco/types/channel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/disco/types/channel.py b/disco/types/channel.py index 3912535..4a26c05 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -5,7 +5,7 @@ from holster.enum import Enum from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property, one_or_many, chunks from disco.types.user import User -from disco.types.base import SlottedModel, Field, ListField, AutoDictField, snowflake, enum, text +from disco.types.base import SlottedModel, Field, AutoDictField, snowflake, enum, text from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.voice.client import VoiceClient @@ -111,7 +111,7 @@ class Channel(SlottedModel, Permissible): last_message_id = Field(snowflake) position = Field(int) bitrate = Field(int) - recipients = ListField(User) + recipients = AutoDictField(User, 'id') type = Field(enum(ChannelType)) overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites') From 514a99dcead85015e6a9d78d766a2963fa0a7d16 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 6 Jan 2017 18:03:34 -0600 Subject: [PATCH 63/91] feature - add Guild.create_channel --- disco/api/client.py | 23 +++++++++++++++++++++-- disco/types/guild.py | 4 ++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index 68c551d..9f5b043 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -196,8 +196,27 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild) - def guilds_channels_create(self, guild, **kwargs): - r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs) + def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[]): + payload = { + 'name': name, + 'channel_type': channel_type, + 'permission_overwrites': [i.to_dict() for i in permission_overwrites], + } + + if channel_type == 'text': + pass + elif channel_type == 'voice': + if bitrate is not None: + payload['bitrate'] = bitrate + + if user_limit is not None: + payload['user_limit'] = user_limit + else: + # TODO: better error here? + raise Exception('Invalid channel type: {}'.format(channel_type)) + + + r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=payload) return Channel.create(self.client, r.json(), guild_id=guild) def guilds_channels_modify(self, guild, channel, position): diff --git a/disco/types/guild.py b/disco/types/guild.py index a9075bb..95baa3f 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -405,3 +405,7 @@ class Guild(SlottedModel, Permissible): def create_ban(self, user, delete_message_days=0): self.client.api.guilds_bans_create(self.id, to_snowflake(user), delete_message_days) + + def create_channel(self, *args, **kwargs): + return self.client.api.guilds_channels_create(self.id, *args, **kwargs) + From 75677ab27267d35aa6d20d21ada718ef98b606a0 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 6 Jan 2017 18:04:03 -0600 Subject: [PATCH 64/91] feature - channel argtype should support channel names --- disco/bot/command.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index b2283da..7e9f5ef 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -162,7 +162,10 @@ class Command(object): return ctx.msg.client.state.users.select_one(username=uid[0], discriminator=uid[1]) def resolve_channel(ctx, cid): - return ctx.msg.guild.channels.get(cid) + if isinstance(cid, (int, long)): + return ctx.msg.guild.channels.get(cid) + else: + return ctx.msg.guild.channels.select_one(name=cid) def resolve_guild(ctx, gid): return ctx.msg.client.state.guilds.get(gid) @@ -170,7 +173,7 @@ class Command(object): self.args = ArgumentSet.from_string(args or '', { 'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True), 'role': self.mention_type([resolve_role], ROLE_MENTION_RE), - 'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE), + 'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE, allow_plain=True), 'guild': self.mention_type([resolve_guild]), }) @@ -182,7 +185,7 @@ class Command(object): self.metadata = kwargs @staticmethod - def mention_type(getters, reg=None, user=False): + def mention_type(getters, reg=None, user=False, allow_plain=False): def _f(ctx, raw): if raw.isdigit(): resolved = int(raw) @@ -191,10 +194,13 @@ class Command(object): resolved = (username, int(discrim)) elif reg: res = reg.match(raw) - if not res: - raise TypeError('Invalid mention: {}'.format(raw)) - - resolved = int(res.group(1)) + if res: + resolved = int(res.group(1)) + else: + if allow_plain: + resolved = raw + else: + raise TypeError('Invalid mention: {}'.format(raw)) else: raise TypeError('Invalid mention: {}'.format(raw)) From eef13d43306d30e28e2a9f7893a3d1b5b27f9f15 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 6 Jan 2017 18:04:23 -0600 Subject: [PATCH 65/91] hackfixes - various fixes, incl fucky CHANNEL_UPDATE stuff --- disco/gateway/events.py | 5 +++-- disco/state.py | 5 +++++ disco/types/base.py | 5 ++++- disco/types/channel.py | 14 ++++++++++++-- disco/types/permissions.py | 5 ++++- 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6430d84..f74431e 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -4,12 +4,12 @@ import inflection import six from disco.types.user import User, Presence -from disco.types.channel import Channel +from disco.types.channel import Channel, PermissionOverwrite from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState from disco.types.guild import Guild, GuildMember, Role, GuildEmoji -from disco.types.base import Model, ModelMeta, Field, ListField, snowflake, lazy_datetime +from disco.types.base import Model, ModelMeta, Field, ListField, AutoDictField, snowflake, lazy_datetime # Mapping of discords event name to our event classes EVENTS_MAP = {} @@ -217,6 +217,7 @@ class ChannelUpdate(ChannelCreate): channel : :class:`disco.types.channel.Channel` The channel which was updated. """ + overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites') @wraps_model(Channel) diff --git a/disco/state.py b/disco/state.py index 86bbcb9..ddca31c 100644 --- a/disco/state.py +++ b/disco/state.py @@ -5,6 +5,7 @@ import inflection from collections import deque, namedtuple from gevent.event import Event +from disco.types.base import UNSET from disco.util.config import Config from disco.util.hashmap import HashMap, DefaultHashMap @@ -211,6 +212,10 @@ class State(object): if event.channel.id in self.channels: self.channels[event.channel.id].update(event.channel) + if event.overwrites is not UNSET: + self.channels[event.channel.id].overwrites = event.overwrites + self.channels[event.channel.id].after_load() + def on_channel_delete(self, event): if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels: del event.channel.guild.channels[event.channel.id] diff --git a/disco/types/base.py b/disco/types/base.py index 0cf0646..d12de6e 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -40,11 +40,12 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, **kwargs): + def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, cast=None, **kwargs): # TODO: fix default bullshit self.src_name = alias self.dst_name = None self.ignore_dump = ignore_dump or [] + self.cast = cast self.metadata = kwargs if default is not None: @@ -101,6 +102,8 @@ class Field(object): elif isinstance(value, Model): return value.to_dict(ignore=(inst.ignore_dump if inst else [])) else: + if inst and inst.cast: + return inst.cast(value) return value def __call__(self, raw, client): diff --git a/disco/types/channel.py b/disco/types/channel.py index 4a26c05..664fd94 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -48,8 +48,8 @@ class PermissionOverwrite(ChannelSubType): """ id = Field(snowflake) type = Field(enum(PermissionOverwriteType)) - allow = Field(PermissionValue) - deny = Field(PermissionValue) + allow = Field(PermissionValue, cast=int) + deny = Field(PermissionValue, cast=int) channel_id = Field(snowflake) @@ -67,6 +67,13 @@ class PermissionOverwrite(ChannelSubType): channel_id=channel.id ).save() + @property + def compiled(self): + value = PermissionValue() + value -= self.deny + value += self.allow + return value + def save(self): self.client.api.channels_permissions_modify(self.channel_id, self.id, @@ -117,7 +124,10 @@ class Channel(SlottedModel, Permissible): def __init__(self, *args, **kwargs): super(Channel, self).__init__(*args, **kwargs) + self.after_load() + def after_load(self): + # TODO: hackfix self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) def __str__(self): diff --git a/disco/types/permissions.py b/disco/types/permissions.py index ff43145..6e4d9f3 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -90,9 +90,12 @@ class PermissionValue(object): else: self.value &= ~Permissions[name].value + def __int__(self): + return self.value + def to_dict(self): return { - k: getattr(self, k) for k in Permissions.attrs + k: getattr(self, k) for k in Permissions.keys_ } @classmethod From 1bba005c34151b4a75d56e65db3a03aed3104cbd Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 11 Jan 2017 20:44:38 -0600 Subject: [PATCH 66/91] feature - add some utility functions for constructing MessageEmbed --- disco/types/message.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/disco/types/message.py b/disco/types/message.py index e9cd82f..7570b6a 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -117,6 +117,24 @@ class MessageEmbed(SlottedModel): author = Field(MessageEmbedAuthor) fields = ListField(MessageEmbedField) + def set_footer(self, *args, **kwargs): + self.footer = MessageEmbedField(*args, **kwargs) + + def set_image(self, *args, **kwargs): + self.image = MessageEmbedImage(*args, **kwargs) + + def set_thumbnail(self, *args, **kwargs): + self.thumbnail = MessageEmbedThumbnail(*args, **kwargs) + + def set_video(self, *args, **kwargs): + self.video = MessageEmbedVideo(*args, **kwargs) + + def set_author(self, *args, **kwargs): + self.author = MessageEmbedAuthor(*args, **kwargs) + + def add_field(self, *args, **kwargs): + self.fields.append(MessageEmbedField(*args, **kwargs)) + class MessageAttachment(SlottedModel): """ From d78fed6641fc1f2c084d93c01f3f742bbce017a3 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 11 Jan 2017 20:44:56 -0600 Subject: [PATCH 67/91] bugfix - fix editing a message with embeds --- disco/api/client.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/disco/api/client.py b/disco/api/client.py index 9f5b043..d6c905b 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -108,9 +108,16 @@ class APIClient(LoggingClass): return Message.create(self.client, r.json()) def channels_messages_modify(self, channel, message, content, embed=None): + payload = { + 'content': content, + } + + if embed: + payload['embed'] = embed.to_dict() + r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, dict(channel=channel, message=message), - json={'content': content, 'embed': embed.to_dict()}) + json=payload) return Message.create(self.client, r.json()) def channels_messages_delete(self, channel, message): From 8d0dbeba1efb7dca79019e51c1d9b1be465b41cf Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 11 Jan 2017 21:00:04 -0600 Subject: [PATCH 68/91] bugfix - properly combine multiple triggers (fixes #14) --- disco/bot/command.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/disco/bot/command.py b/disco/bot/command.py index 7e9f5ef..80f6929 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -5,7 +5,6 @@ from holster.enum import Enum from disco.bot.parser import ArgumentSet, ArgumentError from disco.util.functional import cached_property -REGEX_FMT = '{}' ARGS_REGEX = '(?: ((?:\n|.)*)$|$)' USER_MENTION_RE = re.compile('<@!?([0-9]+)>') @@ -225,7 +224,7 @@ class Command(object): The regex string that defines/triggers this command. """ if self.is_regex: - return REGEX_FMT.format('|'.join(self.triggers)) + return '|'.join(self.triggers) else: group = '' if self.group: @@ -233,7 +232,7 @@ class Command(object): group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group)) else: group = self.group + ' ' - return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX) + return '^{}(?:{})'.format(group, '|'.join(self.triggers)) + ARGS_REGEX def execute(self, event): """ From d2d7166e6692b6af8d6b37629c3e4bfe7e7f6183 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 11 Jan 2017 21:09:57 -0600 Subject: [PATCH 69/91] add User.get_avatar_url to support more formats/sizes --- disco/types/user.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/disco/types/user.py b/disco/types/user.py index b04c180..d4c32cd 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -14,14 +14,20 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): presence = Field(None) - @property - def avatar_url(self): + def get_avatar_url(self, fmt='webp', size=1024): if not self.avatar: return None - return 'https://discordapp.com/api/users/{}/avatars/{}.jpg'.format( + return 'https://cdn.discordapp.com/avatars/{}/{}.{}?size={}'.format( self.id, - self.avatar) + self.avatar, + fmt, + size + ) + + @property + def avatar_url(self): + return self.get_avatar_url() @property def mention(self): From a10aca7f4ac3fccf2c06b49eeff71f7a1f9effce Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 26 Jan 2017 02:33:22 -0800 Subject: [PATCH 70/91] Allow overriding plugin configuration in the top-level bot config --- disco/bot/bot.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index d1a31f6..7a8a433 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -65,6 +65,7 @@ class BotConfig(Config): The directory plugin configuration is located within. """ levels = {} + plugin_config = {} commands_enabled = True commands_require_mention = True @@ -455,17 +456,18 @@ class Bot(LoggingClass): path = os.path.join( self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format - if not os.path.exists(path): - if hasattr(cls, 'config_cls'): - return cls.config_cls() - return + data = {} + if name in self.config.plugin_config: + data = self.config.plugin_config[name] - with open(path, 'r') as f: - data = Serializer.loads(self.config.plugin_config_format, f.read()) + if os.path.exists(path): + with open(path, 'r') as f: + data.update(Serializer.loads(self.config.plugin_config_format, f.read())) if hasattr(cls, 'config_cls'): inst = cls.config_cls() - inst.update(data) + if data: + inst.update(data) return inst return data From 729f778ef4238cd14b4b38840a3d7adfe53e6b45 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 26 Jan 2017 02:34:08 -0800 Subject: [PATCH 71/91] Add utility properties for MessageDeleteBulk --- disco/gateway/events.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index f74431e..fe5c669 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -493,6 +493,14 @@ class MessageDeleteBulk(GatewayEvent): channel_id = Field(snowflake) ids = ListField(snowflake) + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild + @wraps_model(Presence) class PresenceUpdate(GatewayEvent): From 8216ec4ad774aef653c9d888ab8cb55fac71836a Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 26 Jan 2017 02:34:29 -0800 Subject: [PATCH 72/91] Add APIClient.users_me_get and APIClient.users_me_patch Not currently hooked up, but useful nonetheless --- disco/api/client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/disco/api/client.py b/disco/api/client.py index d6c905b..c0f90ea 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -222,7 +222,6 @@ class APIClient(LoggingClass): # TODO: better error here? raise Exception('Invalid channel type: {}'.format(channel_type)) - r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=payload) return Channel.create(self.client, r.json(), guild_id=guild) @@ -305,6 +304,13 @@ class APIClient(LoggingClass): def guilds_emojis_delete(self, guild, emoji): self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji)) + def users_me_get(self): + return User.create(self.client, self.http(Routes.USERS_ME_GET).json()) + + def users_me_patch(self, payload): + r = self.http(Routes.USERS_ME_PATCH, json=payload) + return User.create(self.client, r.json()) + def invites_get(self, invite): r = self.http(Routes.INVITES_GET, dict(invite=invite)) return Invite.create(self.client, r.json()) From 835242c2413eca8f3f378b5c67ec5c3c10b6e330 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 26 Jan 2017 02:34:53 -0800 Subject: [PATCH 73/91] Add GuildMember.name property, returning nickname or username --- disco/types/guild.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/disco/types/guild.py b/disco/types/guild.py index 95baa3f..0d2952e 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -147,6 +147,13 @@ class GuildMember(SlottedModel): def __str__(self): return self.user.__str__() + @property + def name(self): + """ + The nickname of this user if set, otherwise their username + """ + return self.nick or self.user.username + def get_voice_state(self): """ Returns @@ -408,4 +415,3 @@ class Guild(SlottedModel, Permissible): def create_channel(self, *args, **kwargs): return self.client.api.guilds_channels_create(self.id, *args, **kwargs) - From 2bac9f978105454a38f750d20004412ea11410e6 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 26 Jan 2017 17:26:46 -0800 Subject: [PATCH 74/91] Bump version to 0.0.7 --- disco/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/__init__.py b/disco/__init__.py index 0cdff65..262e0b7 100644 --- a/disco/__init__.py +++ b/disco/__init__.py @@ -1 +1 @@ -VERSION = '0.0.6' +VERSION = '0.0.7' From e42e9ebe6308a4bd91ce109b390316d74e0fd3c1 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 16 Feb 2017 11:19:57 -0800 Subject: [PATCH 75/91] add support for loading a level getter from a module/function path --- disco/bot/bot.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 7a8a433..5b664a6 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -140,6 +140,12 @@ class Bot(LoggingClass): if self.config.commands_allow_edit: self.client.events.on('MessageUpdate', self.on_message_update) + # If we have a level getter and its a string, try to load it + if isinstance(self.config.commands_level_getter, (str, unicode)): + mod, func = self.config.commands_level_getter.rsplit('.', 1) + mod = importlib.import_module(mod) + self.config.commands_level_getter = getattr(mod, func) + # Stores the last message for every single channel self.last_message_cache = {} From cc012cd27488d936cdd11ad8349f692b47a20357 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 16 Feb 2017 11:22:05 -0800 Subject: [PATCH 76/91] pass bot instance to commands_level_getter --- disco/bot/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 5b664a6..4d43a27 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -295,7 +295,7 @@ class Bot(LoggingClass): level = CommandLevels.DEFAULT if callable(self.config.commands_level_getter): - level = self.config.commands_level_getter(actor) + level = self.config.commands_level_getter(self, actor) else: if actor.id in self.config.levels: level = self.config.levels[actor.id] From c654f0ca55c3114f4f5aaae20a638684ddc4aa9f Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 16 Feb 2017 21:23:04 -0800 Subject: [PATCH 77/91] Fix MessageEmbed.set_footer --- disco/types/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/types/message.py b/disco/types/message.py index 7570b6a..96b083f 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -118,7 +118,7 @@ class MessageEmbed(SlottedModel): fields = ListField(MessageEmbedField) def set_footer(self, *args, **kwargs): - self.footer = MessageEmbedField(*args, **kwargs) + self.footer = MessageEmbedFooter(*args, **kwargs) def set_image(self, *args, **kwargs): self.image = MessageEmbedImage(*args, **kwargs) From 52b912abfdd304ffc4c8bb3f38104a7a957f0e44 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 18 Feb 2017 01:24:57 -0800 Subject: [PATCH 78/91] type changes --- disco/types/base.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/disco/types/base.py b/disco/types/base.py index d12de6e..670ab07 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -33,7 +33,7 @@ class ConversionError(Exception): def __init__(self, field, raw, e): super(ConversionError, self).__init__( 'Failed to convert `{}` (`{}`) to {}: {}'.format( - str(raw)[:144], field.src_name, field.deserializer, e)) + str(raw)[:144], field.src_name, field.true_type, e)) if six.PY3: self.__cause__ = e @@ -42,6 +42,7 @@ class ConversionError(Exception): class Field(object): def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, cast=None, **kwargs): # TODO: fix default bullshit + self.true_type = value_type self.src_name = alias self.dst_name = None self.ignore_dump = ignore_dump or [] @@ -114,7 +115,9 @@ class DictField(Field): default = HashMap def __init__(self, key_type, value_type=None, **kwargs): - super(DictField, self).__init__(None, **kwargs) + super(DictField, self).__init__({}, **kwargs) + self.true_key_type = key_type + self.true_value_type = value_type self.key_de = self.type_to_deserializer(key_type) self.value_de = self.type_to_deserializer(value_type or key_type) @@ -146,7 +149,7 @@ class AutoDictField(Field): default = HashMap def __init__(self, value_type, key, **kwargs): - super(AutoDictField, self).__init__(None, **kwargs) + super(AutoDictField, self).__init__({}, **kwargs) self.value_de = self.type_to_deserializer(value_type) self.key = key @@ -296,12 +299,19 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): obj = kwargs self.load(obj) + self.validate() + + def validate(self): + pass @property def _fields(self): return self.__class__._fields def load(self, obj, consume=False, skip=None): + return self.load_into(self, obj, consume, skip) + + def load_into(self, inst, obj, consume=False, skip=None): for name, field in six.iteritems(self._fields): should_skip = skip and name in skip @@ -313,16 +323,16 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): # If the field is unset/none, and we have a default we need to set it if (raw in (None, UNSET) or should_skip) and field.has_default(): default = field.default() if callable(field.default) else field.default - setattr(self, field.dst_name, default) + setattr(inst, field.dst_name, default) continue # Otherwise if the field is UNSET and has no default, skip conversion if raw is UNSET or should_skip: - setattr(self, field.dst_name, raw) + setattr(inst, field.dst_name, raw) continue value = field.try_convert(raw, self.client) - setattr(self, field.dst_name, value) + setattr(inst, field.dst_name, value) def update(self, other): for name in six.iterkeys(self._fields): From 94c33ce358de79c916f1c8f253b48e22cfe00405 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 18 Feb 2017 18:20:02 -0800 Subject: [PATCH 79/91] Cleanup VOICE_STATE_UPDATE fixes bug with cached channel property --- disco/state.py | 25 ++++++++++++------------- disco/types/voice.py | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/disco/state.py b/disco/state.py index ddca31c..75531f7 100644 --- a/disco/state.py +++ b/disco/state.py @@ -223,22 +223,21 @@ class State(object): del self.dms[event.channel.id] def on_voice_state_update(self, event): - # Happy path: we have the voice state and want to update/delete it - guild = self.guilds.get(event.state.guild_id) - if not guild: - return - - if event.state.session_id in guild.voice_states: + # Existing connection, we are either moving channels or disconnecting + if event.state.session_id in self.voice_states: + # Moving channels if event.state.channel_id: - guild.voice_states[event.state.session_id].update(event.state) + self.voice_states[event.state.session_id].update(event.state) + # Disconnection else: - del guild.voice_states[event.state.session_id] - - # Prevent a weird race where events come in before the guild_create (I think...) - if event.state.session_id in self.voice_states: - del self.voice_states[event.state.session_id] + if event.state.guild_id in self.guilds: + if event.state.session_id in self.guilds[event.state.guild_id].voice_states: + del self.guilds[event.state.guild_id].voice_states[event.state.session_id] + del self.voice_states[event.state.session_id] + # New connection elif event.state.channel_id: - guild.voice_states[event.state.session_id] = event.state + if event.state.guild_id in self.guilds: + self.guilds[event.state.guild_id].voice_states[event.state.session_id] = event.state self.voice_states[event.state.session_id] = event.state def on_guild_member_add(self, event): diff --git a/disco/types/voice.py b/disco/types/voice.py index 1647eb3..3d7cb32 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -17,7 +17,7 @@ class VoiceState(SlottedModel): def guild(self): return self.client.state.guilds.get(self.guild_id) - @cached_property + @property def channel(self): return self.client.state.channels.get(self.channel_id) From c7a017713a77c8977195c3d593e463ec59abf012 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 21 Feb 2017 16:41:47 -0800 Subject: [PATCH 80/91] Fix some python3 issues - (str, unicode) instead of six.string_types - binary/text conversion on None --- disco/bot/bot.py | 2 +- disco/types/base.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 4d43a27..ce8b4f3 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -141,7 +141,7 @@ class Bot(LoggingClass): self.client.events.on('MessageUpdate', self.on_message_update) # If we have a level getter and its a string, try to load it - if isinstance(self.config.commands_level_getter, (str, unicode)): + if isinstance(self.config.commands_level_getter, six.string_types): mod, func = self.config.commands_level_getter.rsplit('.', 1) mod = importlib.import_module(mod) self.config.commands_level_getter = getattr(mod, func) diff --git a/disco/types/base.py b/disco/types/base.py index 670ab07..ff9937f 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -207,6 +207,9 @@ def datetime(data): def text(obj): + if obj is None: + return None + if six.PY2: if isinstance(obj, str): return obj.decode('utf-8') @@ -216,6 +219,9 @@ def text(obj): def binary(obj): + if obj is None: + return None + if six.PY2: if isinstance(obj, str): return obj.decode('utf-8') From 534a15895be738f290655e14ed1d4f8036caa1c0 Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 21 Feb 2017 16:53:13 -0800 Subject: [PATCH 81/91] Fix Channel.delete_messages --- disco/types/channel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/disco/types/channel.py b/disco/types/channel.py index 664fd94..ee986be 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,5 +1,6 @@ import six +from six.moves import map from holster.enum import Enum from disco.util.snowflake import to_snowflake @@ -298,9 +299,9 @@ class Channel(SlottedModel, Permissible): List of messages (or message ids) to delete. All messages must originate from this channel. """ - messages = map(to_snowflake, messages) + message_ids = list(map(to_snowflake, messages)) - if not messages: + if not message_ids: return if self.can(self.client.state.me, Permissions.MANAGE_MESSAGES) and len(messages) > 2: From 713c1be01b8c97f690c5111f479e2549d28663c6 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 24 Feb 2017 12:02:30 -0800 Subject: [PATCH 82/91] Various fixes, remove lazy_datetime, etc --- disco/bot/plugin.py | 17 +++++++++++++++++ disco/gateway/events.py | 6 +++--- disco/types/base.py | 15 +-------------- disco/types/guild.py | 5 ++--- disco/types/invite.py | 4 ++-- disco/types/message.py | 8 ++++---- 6 files changed, 29 insertions(+), 26 deletions(-) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index a4ef0ea..aeed5b3 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -5,6 +5,7 @@ import inspect import weakref import functools +from gevent.event import AsyncResult from holster.emitter import Priority from disco.util.logging import LoggingClass @@ -208,6 +209,22 @@ class Plugin(LoggingClass, PluginDeco): def handle_exception(self, greenlet, event): pass + def wait_for_event(self, event_name, **kwargs): + result = AsyncResult() + listener = None + + def _event_callback(event): + for k, v in kwargs.items(): + if getattr(event, k) != v: + break + else: + listener.remove() + return result.set(event) + + listener = self.bot.client.events.on(event_name, _event_callback) + + return result + def spawn_wrap(self, spawner, method, *args, **kwargs): def wrapped(*args, **kwargs): self.ctx['plugin'] = self diff --git a/disco/gateway/events.py b/disco/gateway/events.py index fe5c669..3f6bbf1 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -9,7 +9,7 @@ from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState from disco.types.guild import Guild, GuildMember, Role, GuildEmoji -from disco.types.base import Model, ModelMeta, Field, ListField, AutoDictField, snowflake, lazy_datetime +from disco.types.base import Model, ModelMeta, Field, ListField, AutoDictField, snowflake, datetime # Mapping of discords event name to our event classes EVENTS_MAP = {} @@ -244,7 +244,7 @@ class ChannelPinsUpdate(GatewayEvent): The time the last message was pinned. """ channel_id = Field(snowflake) - last_pin_timestamp = Field(lazy_datetime) + last_pin_timestamp = Field(datetime) @proxy(User) @@ -539,7 +539,7 @@ class TypingStart(GatewayEvent): """ channel_id = Field(snowflake) user_id = Field(snowflake) - timestamp = Field(lazy_datetime) + timestamp = Field(datetime) @wraps_model(VoiceState, alias='state') diff --git a/disco/types/base.py b/disco/types/base.py index ff9937f..1d1341a 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -177,26 +177,13 @@ def enum(typ): return _f -# TODO: make lazy -def lazy_datetime(data): +def datetime(data): if not data: return None if isinstance(data, int): return real_datetime.utcfromtimestamp(data) - for fmt in DATETIME_FORMATS: - try: - return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) - except (ValueError, TypeError): - continue - raise ValueError('Failed to conver `{}` to datetime'.format(data)) - - -def datetime(data): - if not data: - return None - for fmt in DATETIME_FORMATS: try: return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) diff --git a/disco/types/guild.py b/disco/types/guild.py index 0d2952e..f7f71c9 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -7,7 +7,7 @@ from disco.api.http import APIException from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.base import ( - SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum + SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum, datetime ) from disco.types.user import User, Presence from disco.types.voice import VoiceState @@ -21,7 +21,6 @@ VerificationLevel = Enum( LOW=1, MEDIUM=2, HIGH=3, - EXTREME=4, ) @@ -141,7 +140,7 @@ class GuildMember(SlottedModel): nick = Field(text) mute = Field(bool) deaf = Field(bool) - joined_at = Field(str) + joined_at = Field(datetime) roles = ListField(snowflake) def __str__(self): diff --git a/disco/types/invite.py b/disco/types/invite.py index 906a360..0cc4852 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,4 +1,4 @@ -from disco.types.base import SlottedModel, Field, lazy_datetime +from disco.types.base import SlottedModel, Field, datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel @@ -37,7 +37,7 @@ class Invite(SlottedModel): max_uses = Field(int) uses = Field(int) temporary = Field(bool) - created_at = Field(lazy_datetime) + created_at = Field(datetime) @classmethod def create(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False): diff --git a/disco/types/message.py b/disco/types/message.py index 96b083f..821c9c3 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -7,7 +7,7 @@ from holster.enum import Enum from disco.types.base import ( SlottedModel, Field, ListField, AutoDictField, snowflake, text, - lazy_datetime, enum + datetime, enum ) from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property @@ -108,7 +108,7 @@ class MessageEmbed(SlottedModel): type = Field(str, default='rich') description = Field(text) url = Field(text) - timestamp = Field(lazy_datetime) + timestamp = Field(datetime) color = Field(int) footer = Field(MessageEmbedFooter) image = Field(MessageEmbedImage) @@ -210,8 +210,8 @@ class Message(SlottedModel): author = Field(User) content = Field(text) nonce = Field(snowflake) - timestamp = Field(lazy_datetime) - edited_timestamp = Field(lazy_datetime) + timestamp = Field(datetime) + edited_timestamp = Field(datetime) tts = Field(bool) mention_everyone = Field(bool) pinned = Field(bool) From c5848dbe8b66295598b4e6cad1c5ec8f5cc3a5fb Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 13 Mar 2017 12:26:19 -0700 Subject: [PATCH 83/91] Fix Channel.delete_messages --- disco/types/channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/types/channel.py b/disco/types/channel.py index ee986be..4c91d3c 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -305,7 +305,7 @@ class Channel(SlottedModel, Permissible): return if self.can(self.client.state.me, Permissions.MANAGE_MESSAGES) and len(messages) > 2: - for chunk in chunks(messages, 100): + for chunk in chunks(message_ids, 100): self.client.api.channels_messages_delete_bulk(self.id, chunk) else: for msg in messages: From e5a97a3c334b28e7da1866ffe188792ff8a36e8e Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 15 Mar 2017 17:06:24 -0700 Subject: [PATCH 84/91] Fix command_matches_re not being case insensitive --- disco/bot/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disco/bot/bot.py b/disco/bot/bot.py index ce8b4f3..7693f50 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -225,7 +225,7 @@ class Bot(LoggingClass): commands = list(self.commands) re_str = '|'.join(command.regex for command in commands) if re_str: - self.command_matches_re = re.compile(re_str) + self.command_matches_re = re.compile(re_str, re.I) else: self.command_matches_re = None From 12031c3d7315bae62c24fa41a7fa61d88f3b0c46 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 15 Mar 2017 23:31:46 -0700 Subject: [PATCH 85/91] Add GuildMember.modify --- disco/api/client.py | 2 +- disco/types/guild.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/disco/api/client.py b/disco/api/client.py index c0f90ea..0c00539 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -240,7 +240,7 @@ class APIClient(LoggingClass): return GuildMember.create(self.client, r.json(), guild_id=guild) def guilds_members_modify(self, guild, member, **kwargs): - self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) + self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=optional(**kwargs)) def guilds_members_roles_add(self, guild, member, role): self.http(Routes.GUILDS_MEMBERS_ROLES_ADD, dict(guild=guild, member=member, role=role)) diff --git a/disco/types/guild.py b/disco/types/guild.py index f7f71c9..b478cd9 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -200,6 +200,9 @@ class GuildMember(SlottedModel): else: self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') + def modify(self, **kwargs): + self.client.api.guilds_members_modify(self.guild.id, self.user.id, **kwargs) + def add_role(self, role): self.client.api.guilds_members_roles_add(self.guild.id, self.user.id, to_snowflake(role)) From a8793c869f323020997f6a680ac194fc8fc94a28 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 16 Mar 2017 19:05:50 -0700 Subject: [PATCH 86/91] Add duration arg type, fix GuildMember.remove_role --- disco/bot/parser.py | 7 ++++++- disco/types/guild.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 16ee3f2..722abe6 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -2,7 +2,6 @@ import re import six import copy - # Regex which splits out argument parts PARTS_RE = re.compile('(\<|\[|\{)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\]|\})') @@ -16,6 +15,12 @@ TYPE_MAP = { 'snowflake': lambda ctx, data: int(data), } +try: + import dateparser + TYPE_MAP['duration'] = lambda ctx, data: dateparser.parse(data, settings={'TIMEZONE': 'UTC'}) +except ImportError: + pass + def to_bool(ctx, data): if data in BOOL_OPTS: diff --git a/disco/types/guild.py b/disco/types/guild.py index b478cd9..f93a88f 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -207,7 +207,7 @@ class GuildMember(SlottedModel): self.client.api.guilds_members_roles_add(self.guild.id, self.user.id, to_snowflake(role)) def remove_role(self, role): - self.clients.api.guilds_members_roles_remove(self.guild.id, self.user.id, to_snowflake(role)) + self.client.api.guilds_members_roles_remove(self.guild.id, self.user.id, to_snowflake(role)) @cached_property def owner(self): From 66433be371eacb331c63ae56acdbf8a2e9adfbb2 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 18 Mar 2017 03:58:04 -0700 Subject: [PATCH 87/91] Add default avatars --- disco/types/user.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/disco/types/user.py b/disco/types/user.py index d4c32cd..3192abc 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -2,6 +2,14 @@ from holster.enum import Enum from disco.types.base import SlottedModel, Field, snowflake, text, binary, with_equality, with_hash +DefaultAvatars = Enum( + BLURPLE=0, + GREY=1, + GREEN=2, + ORANGE=3, + RED=4, +) + class User(SlottedModel, with_equality('id'), with_hash('id')): id = Field(snowflake) @@ -16,7 +24,7 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): def get_avatar_url(self, fmt='webp', size=1024): if not self.avatar: - return None + return 'https://cdn.discordapp.com/embed/avatars/{}.png'.format(self.default_avatar.value) return 'https://cdn.discordapp.com/avatars/{}/{}.{}?size={}'.format( self.id, @@ -25,6 +33,10 @@ class User(SlottedModel, with_equality('id'), with_hash('id')): size ) + @property + def default_avatar(self): + return DefaultAvatars[int(self.discriminator) % len(DefaultAvatars.attrs)] + @property def avatar_url(self): return self.get_avatar_url() From 1a9b0d2e764a85e780eb204229e9a2c804f43cfa Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 18 Mar 2017 16:16:15 -0700 Subject: [PATCH 88/91] Store raw argument string on the Command object --- disco/bot/command.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/disco/bot/command.py b/disco/bot/command.py index 80f6929..3d0398a 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -129,6 +129,7 @@ class Command(object): self.triggers = [trigger] self.dispatch_func = None + self.raw_args = None self.args = None self.level = None self.group = None @@ -169,6 +170,7 @@ class Command(object): def resolve_guild(ctx, gid): return ctx.msg.client.state.guilds.get(gid) + self.raw_args = args self.args = ArgumentSet.from_string(args or '', { 'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True), 'role': self.mention_type([resolve_role], ROLE_MENTION_RE), From b00e82da634ea933066a2f65018c6059caa24e4a Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 18 Mar 2017 21:33:08 -0700 Subject: [PATCH 89/91] Add MessageReactionRemoveAll utils, Message.get_reactors --- disco/gateway/events.py | 8 ++++++++ disco/types/message.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 3f6bbf1..5033b03 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -658,3 +658,11 @@ class MessageReactionRemoveAll(GatewayEvent): """ channel_id = Field(snowflake) message_id = Field(snowflake) + + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild diff --git a/disco/types/message.py b/disco/types/message.py index 821c9c3..9b33d11 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -299,6 +299,21 @@ class Message(SlottedModel): """ return self.client.api.channels_messages_delete(self.channel_id, self.id) + def get_reactors(self, emoji): + """ + Returns an list of users who reacted to this message with the given emoji. + + Returns + ------- + list(:class:`User`) + The users who reacted. + """ + return self.client.api.channels_messages_reactions_get( + self.channel_id, + self.id, + emoji + ) + def create_reaction(self, emoji): if isinstance(emoji, Emoji): emoji = emoji.to_string() From 19201517fdaca6429cafb5d1eb4ef6edacb5e76f Mon Sep 17 00:00:00 2001 From: Andrei Date: Tue, 28 Mar 2017 17:36:06 -0700 Subject: [PATCH 90/91] Fix smashing important state on GUILD_UPDATEs GUILD_UPDATES are cool and special and of course they are partial. Although this is logical, our type/models autoinitialize some fields by default (which is actually fairly sane). However when this happens, we smash these new blank mappings over the previously updated state. Instead we should just ignore fields that don't come in GUILD_UPDATEs, and save our state. --- disco/state.py | 7 ++++++- disco/types/base.py | 5 ++++- disco/types/channel.py | 4 ++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/disco/state.py b/disco/state.py index 75531f7..689ab62 100644 --- a/disco/state.py +++ b/disco/state.py @@ -193,7 +193,12 @@ class State(object): event.guild.sync() def on_guild_update(self, event): - self.guilds[event.guild.id].update(event.guild) + self.guilds[event.guild.id].update(event.guild, ignored=[ + 'channels', + 'members', + 'voice_states', + 'presences' + ]) def on_guild_delete(self, event): if event.id in self.guilds: diff --git a/disco/types/base.py b/disco/types/base.py index 1d1341a..c3a061b 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -327,8 +327,11 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): value = field.try_convert(raw, self.client) setattr(inst, field.dst_name, value) - def update(self, other): + def update(self, other, ignored=None): for name in six.iterkeys(self._fields): + if ignored and name in ignored: + continue + if hasattr(other, name) and not getattr(other, name) is UNSET: setattr(self, name, getattr(other, name)) diff --git a/disco/types/channel.py b/disco/types/channel.py index 4c91d3c..311ca5c 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -149,8 +149,8 @@ class Channel(SlottedModel, Permissible): if not self.guild_id: return Permissions.ADMINISTRATOR - member = self.guild.members.get(user.id) - base = self.guild.get_permissions(user) + member = self.guild.get_member(user) + base = self.guild.get_permissions(member) for ow in six.itervalues(self.overwrites): if ow.id != user.id and ow.id not in member.roles: From b341ae9aee9a6448c518ef5bdc91efe8a385beed Mon Sep 17 00:00:00 2001 From: Andrei Date: Mon, 3 Apr 2017 17:50:13 -0700 Subject: [PATCH 91/91] Fix missing \n --- disco/util/logging.py | 1 + 1 file changed, 1 insertion(+) diff --git a/disco/util/logging.py b/disco/util/logging.py index 68af8a8..75e9229 100644 --- a/disco/util/logging.py +++ b/disco/util/logging.py @@ -9,6 +9,7 @@ LEVEL_OVERRIDES = { LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s' + def setup_logging(**kwargs): kwargs.setdefault('format', LOG_FORMAT)