diff --git a/README.md b/README.md index 027229d..ae77c31 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ Disco was built to run both as a generic-use library, and a standalone bot toolk |requests[security]|adds packages for a proper SSL implementation| |ujson|faster json parser, improves performance| |erlpack|ETF parser, only Python 2.x, run with the --encoder=etf flag| +|gipc|Gevent IPC, required for autosharding| ## Examples @@ -48,7 +49,7 @@ class SimplePlugin(Plugin): Using the default bot configuration, we can now run this script like so: -`python -m disco.cli --token="MY_DISCORD_TOKEN" --bot --plugin simpleplugin` +`python -m disco.cli --token="MY_DISCORD_TOKEN" --run-bot --plugin simpleplugin` And commands can be triggered by mentioning the bot (configued by the BotConfig.command\_require\_mention flag): diff --git a/disco/__init__.py b/disco/__init__.py index 6e83b38..262e0b7 100644 --- a/disco/__init__.py +++ b/disco/__init__.py @@ -1 +1 @@ -VERSION = '0.0.5' +VERSION = '0.0.7' diff --git a/disco/api/client.py b/disco/api/client.py index ac83dea..0c00539 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -1,11 +1,12 @@ import six +import json from disco.api.http import Routes, HTTPClient from disco.util.logging import LoggingClass from disco.types.user import User from disco.types.message import Message -from disco.types.guild import Guild, GuildMember, Role +from disco.types.guild import Guild, GuildMember, GuildBan, Role, GuildEmoji from disco.types.channel import Channel from disco.types.invite import Invite from disco.types.webhook import Webhook @@ -23,18 +24,40 @@ def optional(**kwargs): class APIClient(LoggingClass): """ - An abstraction over the :class:`disco.api.http.HTTPClient` that composes requests, and fits - the models with the returned data. + An abstraction over a :class:`disco.api.http.HTTPClient`, which composes + requests from provided data, and fits models with the returned data. The APIClient + is the only path to the API used within models/other interfaces, and it's + the recommended path for all third-party users/implementations. + + Args + ---- + token : str + The Discord authentication token (without prefixes) to be used for all + HTTP requests. + client : Optional[:class:`disco.client.Client`] + The Disco client this APIClient is a member of. This is used when constructing + and fitting models from response data. + + Attributes + ---------- + client : Optional[:class:`disco.client.Client`] + The Disco client this APIClient is a member of. + http : :class:`disco.http.HTTPClient` + The HTTPClient this APIClient uses for all requests. """ - def __init__(self, client): + def __init__(self, token, client=None): super(APIClient, self).__init__() self.client = client - self.http = HTTPClient(self.client.config.token) + self.http = HTTPClient(token) - def gateway(self, version, encoding): + def gateway_get(self): data = self.http(Routes.GATEWAY_GET).json() - return data['url'] + '?v={}&encoding={}'.format(version, encoding) + return data + + def gateway_bot_get(self): + data = self.http(Routes.GATEWAY_BOT_GET).json() + return data def channels_get(self, channel): r = self.http(Routes.CHANNELS_GET, dict(channel=channel)) @@ -48,6 +71,9 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_DELETE, dict(channel=channel)) return Channel.create(self.client, r.json()) + def channels_typing(self, channel): + self.http(Routes.CHANNELS_TYPING, dict(channel=channel)) + def channels_messages_list(self, channel, around=None, before=None, after=None, limit=50): r = self.http(Routes.CHANNELS_MESSAGES_LIST, dict(channel=channel), params=optional( around=around, @@ -62,19 +88,36 @@ class APIClient(LoggingClass): r = self.http(Routes.CHANNELS_MESSAGES_GET, dict(channel=channel, message=message)) return Message.create(self.client, r.json()) - def channels_messages_create(self, channel, content, nonce=None, tts=False): - r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json={ + def channels_messages_create(self, channel, content, nonce=None, tts=False, attachment=None, embed=None): + payload = { 'content': content, 'nonce': nonce, 'tts': tts, - }) + } + + if embed: + payload['embed'] = embed.to_dict() + + if attachment: + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), data={'payload_json': json.dumps(payload)}, files={ + 'file': (attachment[0], attachment[1]) + }) + else: + r = self.http(Routes.CHANNELS_MESSAGES_CREATE, dict(channel=channel), json=payload) return Message.create(self.client, r.json()) - def channels_messages_modify(self, channel, message, content): + def channels_messages_modify(self, channel, message, content, embed=None): + payload = { + 'content': content, + } + + if embed: + payload['embed'] = embed.to_dict() + r = self.http(Routes.CHANNELS_MESSAGES_MODIFY, - dict(channel=channel, message=message), - json={'content': content}) + dict(channel=channel, message=message), + json=payload) return Message.create(self.client, r.json()) def channels_messages_delete(self, channel, message): @@ -83,6 +126,23 @@ class APIClient(LoggingClass): def channels_messages_delete_bulk(self, channel, messages): self.http(Routes.CHANNELS_MESSAGES_DELETE_BULK, dict(channel=channel), json={'messages': messages}) + def channels_messages_reactions_get(self, channel, message, emoji): + r = self.http(Routes.CHANNELS_MESSAGES_REACTIONS_GET, dict(channel=channel, message=message, emoji=emoji)) + return User.create_map(self.client, r.json()) + + def channels_messages_reactions_create(self, channel, message, emoji): + self.http(Routes.CHANNELS_MESSAGES_REACTIONS_CREATE, dict(channel=channel, message=message, emoji=emoji)) + + def channels_messages_reactions_delete(self, channel, message, emoji, user=None): + route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_ME + obj = dict(channel=channel, message=message, emoji=emoji) + + if user: + route = Routes.CHANNELS_MESSAGES_REACTIONS_DELETE_USER + obj['user'] = user + + self.http(route, obj) + def channels_permissions_modify(self, channel, permission, allow, deny, typ): self.http(Routes.CHANNELS_PERMISSIONS_MODIFY, dict(channel=channel, permission=permission), json={ 'allow': allow, @@ -141,10 +201,28 @@ class APIClient(LoggingClass): def guilds_channels_list(self, guild): r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) - return Channel.create_map(self.client, r.json(), guild_id=guild) - - def guilds_channels_create(self, guild, **kwargs): - r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs) + return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild) + + def guilds_channels_create(self, guild, name, channel_type, bitrate=None, user_limit=None, permission_overwrites=[]): + payload = { + 'name': name, + 'channel_type': channel_type, + 'permission_overwrites': [i.to_dict() for i in permission_overwrites], + } + + if channel_type == 'text': + pass + elif channel_type == 'voice': + if bitrate is not None: + payload['bitrate'] = bitrate + + if user_limit is not None: + payload['user_limit'] = user_limit + else: + # TODO: better error here? + raise Exception('Invalid channel type: {}'.format(channel_type)) + + r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=payload) return Channel.create(self.client, r.json(), guild_id=guild) def guilds_channels_modify(self, guild, channel, position): @@ -155,21 +233,30 @@ class APIClient(LoggingClass): def guilds_members_list(self, guild): r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild)) - return GuildMember.create_map(self.client, r.json(), guild_id=guild) + return GuildMember.create_hash(self.client, 'id', r.json(), guild_id=guild) def guilds_members_get(self, guild, member): r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member)) return GuildMember.create(self.client, r.json(), guild_id=guild) def guilds_members_modify(self, guild, member, **kwargs): - self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) + self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=optional(**kwargs)) + + def guilds_members_roles_add(self, guild, member, role): + self.http(Routes.GUILDS_MEMBERS_ROLES_ADD, dict(guild=guild, member=member, role=role)) + + def guilds_members_roles_remove(self, guild, member, role): + self.http(Routes.GUILDS_MEMBERS_ROLES_REMOVE, dict(guild=guild, member=member, role=role)) + + def guilds_members_me_nick(self, guild, nick): + self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick}) def guilds_members_kick(self, guild, member): self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member)) def guilds_bans_list(self, guild): r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild)) - return User.create_map(self.client, r.json()) + return GuildBan.create_hash(self.client, 'user.id', r.json()) def guilds_bans_create(self, guild, user, delete_message_days): self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={ @@ -202,6 +289,28 @@ class APIClient(LoggingClass): r = self.http(Routes.GUILDS_WEBHOOKS_LIST, dict(guild=guild)) return Webhook.create_map(self.client, r.json()) + def guilds_emojis_list(self, guild): + r = self.http(Routes.GUILDS_EMOJIS_LIST, dict(guild=guild)) + return GuildEmoji.create_map(self.client, r.json()) + + def guilds_emojis_create(self, guild, **kwargs): + r = self.http(Routes.GUILDS_EMOJIS_CREATE, dict(guild=guild), json=kwargs) + return GuildEmoji.create(self.client, r.json()) + + def guilds_emojis_modify(self, guild, emoji, **kwargs): + r = self.http(Routes.GUILDS_EMOJIS_MODIFY, dict(guild=guild, emoji=emoji), json=kwargs) + return GuildEmoji.create(self.client, r.json()) + + def guilds_emojis_delete(self, guild, emoji): + self.http(Routes.GUILDS_EMOJIS_DELETE, dict(guild=guild, emoji=emoji)) + + def users_me_get(self): + return User.create(self.client, self.http(Routes.USERS_ME_GET).json()) + + def users_me_patch(self, payload): + r = self.http(Routes.USERS_ME_PATCH, json=payload) + return User.create(self.client, r.json()) + def invites_get(self, invite): r = self.http(Routes.INVITES_GET, dict(invite=invite)) return Invite.create(self.client, r.json()) @@ -236,7 +345,7 @@ class APIClient(LoggingClass): return Webhook.create(self.client, r.json()) def webhooks_token_delete(self, webhook, token): - self.http(Routes.WEBHOOKS_TOKEN_DLEETE, dict(webhook=webhook, token=token)) + self.http(Routes.WEBHOOKS_TOKEN_DELETE, dict(webhook=webhook, token=token)) def webhooks_token_execute(self, webhook, token, data, wait=False): obj = self.http( diff --git a/disco/api/http.py b/disco/api/http.py index c1930bd..088b69e 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -2,9 +2,12 @@ import requests import random import gevent import six +import sys from holster.enum import Enum +from disco import VERSION as disco_version +from requests import __version__ as requests_version from disco.util.logging import LoggingClass from disco.api.ratelimit import RateLimiter @@ -18,6 +21,12 @@ HTTPMethod = Enum( ) +def to_bytes(obj): + if isinstance(obj, six.text_type): + return obj.encode('utf-8') + return obj + + class Routes(object): """ Simple Python object-enum of all method/url route combinations available to @@ -25,18 +34,25 @@ class Routes(object): """ # Gateway GATEWAY_GET = (HTTPMethod.GET, '/gateway') + GATEWAY_BOT_GET = (HTTPMethod.GET, '/gateway/bot') # Channels CHANNELS = '/channels/{channel}' CHANNELS_GET = (HTTPMethod.GET, CHANNELS) CHANNELS_MODIFY = (HTTPMethod.PATCH, CHANNELS) CHANNELS_DELETE = (HTTPMethod.DELETE, CHANNELS) + CHANNELS_TYPING = (HTTPMethod.POST, CHANNELS + '/typing') CHANNELS_MESSAGES_LIST = (HTTPMethod.GET, CHANNELS + '/messages') CHANNELS_MESSAGES_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_CREATE = (HTTPMethod.POST, CHANNELS + '/messages') CHANNELS_MESSAGES_MODIFY = (HTTPMethod.PATCH, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_DELETE = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}') CHANNELS_MESSAGES_DELETE_BULK = (HTTPMethod.POST, CHANNELS + '/messages/bulk_delete') + CHANNELS_MESSAGES_REACTIONS_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}/reactions/{emoji}') + CHANNELS_MESSAGES_REACTIONS_CREATE = (HTTPMethod.PUT, CHANNELS + '/messages/{message}/reactions/{emoji}/@me') + CHANNELS_MESSAGES_REACTIONS_DELETE_ME = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/@me') + CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE, + CHANNELS + '/messages/{message}/reactions/{emoji}/{user}') CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}') CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}') CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites') @@ -58,6 +74,9 @@ class Routes(object): GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members') GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}') GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}') + GUILDS_MEMBERS_ROLES_ADD = (HTTPMethod.PUT, GUILDS + '/members/{member}/roles/{role}') + GUILDS_MEMBERS_ROLES_REMOVE = (HTTPMethod.DELETE, GUILDS + '/members/{member}/roles/{role}') + GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick') GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}') GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans') GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}') @@ -79,6 +98,10 @@ class Routes(object): GUILDS_EMBED_GET = (HTTPMethod.GET, GUILDS + '/embed') GUILDS_EMBED_MODIFY = (HTTPMethod.PATCH, GUILDS + '/embed') GUILDS_WEBHOOKS_LIST = (HTTPMethod.GET, GUILDS + '/webhooks') + GUILDS_EMOJIS_LIST = (HTTPMethod.GET, GUILDS + '/emojis') + GUILDS_EMOJIS_CREATE = (HTTPMethod.POST, GUILDS + '/emojis') + GUILDS_EMOJIS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/emojis/{emoji}') + GUILDS_EMOJIS_DELETE = (HTTPMethod.DELETE, GUILDS + '/emojis/{emoji}') # Users USERS = '/users' @@ -111,14 +134,39 @@ class APIException(Exception): """ Exception thrown when an HTTP-client level error occurs. Usually this will be a non-success status-code, or a transient network issue. + + Attributes + ---------- + status_code : int + The status code returned by the API for the request that triggered this + error. """ - def __init__(self, msg, status_code=0, content=None): - self.status_code = status_code - self.content = content - self.msg = msg + def __init__(self, response, retries=None): + self.response = response + self.retries = retries + + self.code = 0 + self.msg = 'Request Failed ({})'.format(response.status_code) - if self.status_code: - self.msg += ' code: {}'.format(status_code) + if self.retries: + self.msg += " after {} retries".format(self.retries) + + # Try to decode JSON, and extract params + try: + data = self.response.json() + + if 'code' in data: + self.code = data['code'] + self.msg = data['message'] + elif len(data) == 1: + key, value = list(data.items())[0] + self.msg = 'Request Failed: {}: {}'.format(key, ', '.join(value)) + except ValueError: + pass + + # DEPRECATED: left for backwards compat + self.status_code = response.status_code + self.content = response.content super(APIException, self).__init__(self.msg) @@ -134,9 +182,18 @@ class HTTPClient(LoggingClass): def __init__(self, token): super(HTTPClient, self).__init__() + py_version = '{}.{}.{}'.format( + sys.version_info.major, + sys.version_info.minor, + sys.version_info.micro) + self.limiter = RateLimiter() self.headers = { 'Authorization': 'Bot ' + token, + 'User-Agent': 'DiscordBot (https://github.com/b1naryth1ef/disco {}) Python/{} requests/{}'.format( + disco_version, + py_version, + requests_version), } def __call__(self, route, args=None, **kwargs): @@ -182,7 +239,8 @@ class HTTPClient(LoggingClass): kwargs['headers'] = self.headers # Build the bucket URL - filtered = {k: (v if v in ('guild', 'channel') else '') for k, v in six.iteritems(args)} + args = {k: to_bytes(v) for k, v in six.iteritems(args)} + filtered = {k: (v if k in ('guild', 'channel') else '') for k, v in six.iteritems(args)} bucket = (route[0].value, route[1].format(**filtered)) # Possibly wait if we're rate limited @@ -190,6 +248,7 @@ class HTTPClient(LoggingClass): # Make the actual request url = self.BASE_URL + route[1].format(**args) + self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params')) r = requests.request(route[0].value, url, **kwargs) # Update rate limiter @@ -198,17 +257,18 @@ class HTTPClient(LoggingClass): # If we got a success status code, just return the data if r.status_code < 400: return r - elif r.status_code != 429 and 400 < r.status_code < 500: - raise APIException('Request failed', r.status_code, r.content) + elif r.status_code != 429 and 400 <= r.status_code < 500: + raise APIException(r) else: if r.status_code == 429: - self.log.warning('Request responded w/ 429, retrying (but this should not happen, check your clock sync') + self.log.warning( + 'Request responded w/ 429, retrying (but this should not happen, check your clock sync') # If we hit the max retries, throw an error retry += 1 if retry > self.MAX_RETRIES: self.log.error('Failing request, hit max retries') - raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content) + raise APIException(r, retries=self.MAX_RETRIES) backoff = self.random_backoff() self.log.warning('Request to `{}` failed with code {}, retrying after {}s ({})'.format( diff --git a/disco/api/ratelimit.py b/disco/api/ratelimit.py index 420d6f3..054c8cf 100644 --- a/disco/api/ratelimit.py +++ b/disco/api/ratelimit.py @@ -1,8 +1,10 @@ import time import gevent +from disco.util.logging import LoggingClass -class RouteState(object): + +class RouteState(LoggingClass): """ An object which stores ratelimit state for a given method/url route combination (as specified in :class:`disco.api.http.Routes`). @@ -36,10 +38,13 @@ class RouteState(object): self.update(response) + def __repr__(self): + return ''.format(' '.join(self.route)) + @property def chilled(self): """ - Whether this route is currently being cooldown (aka waiting until reset_time) + Whether this route is currently being cooldown (aka waiting until reset_time). """ return self.event is not None @@ -69,7 +74,7 @@ class RouteState(object): def wait(self, timeout=None): """ - Waits until this route is no longer under a cooldown + Waits until this route is no longer under a cooldown. Parameters ---------- @@ -80,24 +85,26 @@ class RouteState(object): Returns ------- bool - False if the timeout period expired before the cooldown was finished + False if the timeout period expired before the cooldown was finished. """ return self.event.wait(timeout) def cooldown(self): """ - Waits for the current route to be cooled-down (aka waiting until reset time) + Waits for the current route to be cooled-down (aka waiting until reset time). """ if self.reset_time - time.time() < 0: raise Exception('Cannot cooldown for negative time period; check clock sync') self.event = gevent.event.Event() - gevent.sleep((self.reset_time - time.time()) + .5) + delay = (self.reset_time - time.time()) + .5 + self.log.debug('Cooling down bucket %s for %s seconds', self, delay) + gevent.sleep(delay) self.event.set() self.event = None -class RateLimiter(object): +class RateLimiter(LoggingClass): """ A in-memory store of ratelimit states for all routes we've ever called. diff --git a/disco/bot/bot.py b/disco/bot/bot.py index 6d5f250..7693f50 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -12,6 +12,7 @@ from disco.bot.plugin import Plugin from disco.bot.command import CommandEvent, CommandLevels from disco.bot.storage import Storage from disco.util.config import Config +from disco.util.logging import LoggingClass from disco.util.serializer import Serializer @@ -64,7 +65,7 @@ class BotConfig(Config): The directory plugin configuration is located within. """ levels = {} - plugins = [] + plugin_config = {} commands_enabled = True commands_require_mention = True @@ -88,7 +89,7 @@ class BotConfig(Config): storage_config = {} -class Bot(object): +class Bot(LoggingClass): """ Disco's implementation of a simple but extendable Discord bot. Bots consist of a set of plugins, and a Disco client. @@ -114,6 +115,9 @@ class Bot(object): self.client = client self.config = config or BotConfig() + # Shard manager + self.shards = None + # The context carries information about events in a threadlocal storage self.ctx = ThreadLocal() @@ -122,6 +126,7 @@ class Bot(object): if self.config.storage_enabled: self.storage = Storage(self.ctx, self.config.from_prefix('storage')) + # If the manhole is enabled, add this bot as a local if self.client.config.manhole_enable: self.client.manhole_locals['bot'] = self @@ -135,6 +140,12 @@ class Bot(object): if self.config.commands_allow_edit: self.client.events.on('MessageUpdate', self.on_message_update) + # If we have a level getter and its a string, try to load it + if isinstance(self.config.commands_level_getter, six.string_types): + mod, func = self.config.commands_level_getter.rsplit('.', 1) + mod = importlib.import_module(mod) + self.config.commands_level_getter = getattr(mod, func) + # Stores the last message for every single channel self.last_message_cache = {} @@ -173,10 +184,10 @@ class Bot(object): @property def commands(self): """ - Generator of all commands this bots plugins have defined + Generator of all commands this bots plugins have defined. """ for plugin in six.itervalues(self.plugins): - for command in six.itervalues(plugin.commands): + for command in plugin.commands: yield command def recompute(self): @@ -190,7 +201,7 @@ class Bot(object): def compute_group_abbrev(self): """ - Computes all possible abbreviations for a command grouping + Computes all possible abbreviations for a command grouping. """ self.group_abbrev = {} groups = set(command.group for command in self.commands if command.group) @@ -199,7 +210,7 @@ class Bot(object): grp = group while grp: # If the group already exists, means someone else thought they - # could use it so we need to + # could use it so we need yank it from them (and not use it) if grp in list(six.itervalues(self.group_abbrev)): self.group_abbrev = {k: v for k, v in six.iteritems(self.group_abbrev) if v != grp} else: @@ -211,13 +222,14 @@ class Bot(object): """ Computes a single regex which matches all possible command combinations. """ - re_str = '|'.join(command.regex for command in self.commands) + commands = list(self.commands) + re_str = '|'.join(command.regex for command in commands) if re_str: - self.command_matches_re = re.compile(re_str) + self.command_matches_re = re.compile(re_str, re.I) else: self.command_matches_re = None - def get_commands_for_message(self, msg): + def get_commands_for_message(self, require_mention, mention_rules, prefix, msg): """ Generator of all commands that a given message object triggers, based on the bots plugins and configuration. @@ -234,19 +246,19 @@ class Bot(object): """ content = msg.content - if self.config.commands_require_mention: + if require_mention: mention_direct = msg.is_mentioned(self.client.state.me) mention_everyone = msg.mention_everyone mention_roles = [] if msg.guild: mention_roles = list(filter(lambda r: msg.is_mentioned(r), - msg.guild.get_member(self.client.state.me).roles)) + msg.guild.get_member(self.client.state.me).roles)) if not any(( - self.config.commands_mention_rules['user'] and mention_direct, - self.config.commands_mention_rules['everyone'] and mention_everyone, - self.config.commands_mention_rules['role'] and any(mention_roles), + mention_rules.get('user', True) and mention_direct, + mention_rules.get('everyone', False) and mention_everyone, + mention_rules.get('role', False) and any(mention_roles), msg.channel.is_dm )): raise StopIteration @@ -262,14 +274,14 @@ class Bot(object): content = content.replace('@everyone', '', 1) else: for role in mention_roles: - content = content.replace(role.mention, '', 1) + content = content.replace('<@{}>'.format(role), '', 1) content = content.lstrip() - if self.config.commands_prefix and not content.startswith(self.config.commands_prefix): + if prefix and not content.startswith(prefix): raise StopIteration else: - content = content[len(self.config.commands_prefix):] + content = content[len(prefix):] if not self.command_matches_re or not self.command_matches_re.match(content): raise StopIteration @@ -283,7 +295,7 @@ class Bot(object): level = CommandLevels.DEFAULT if callable(self.config.commands_level_getter): - level = self.config.commands_level_getter(actor) + level = self.config.commands_level_getter(self, actor) else: if actor.id in self.config.levels: level = self.config.levels[actor.id] @@ -320,19 +332,24 @@ class Bot(object): bool whether any commands where successfully triggered by the message """ - commands = list(self.get_commands_for_message(msg)) + commands = list(self.get_commands_for_message( + self.config.commands_require_mention, + self.config.commands_mention_rules, + self.config.commands_prefix, + msg + )) - if len(commands): - result = False - for command, match in commands: - if not self.check_command_permissions(command, msg): - continue + if not len(commands): + return False - if command.plugin.execute(CommandEvent(command, msg, match)): - result = True - return result + result = False + for command, match in commands: + if not self.check_command_permissions(command, msg): + continue - return False + if command.plugin.execute(CommandEvent(command, msg, match)): + result = True + return result def on_message_create(self, event): if event.message.author.id == self.client.state.me.id: @@ -356,7 +373,7 @@ class Bot(object): self.last_message_cache[msg.channel_id] = (msg, triggered) - def add_plugin(self, cls, config=None): + def add_plugin(self, cls, config=None, ctx=None): """ Adds and loads a plugin, based on its class. @@ -366,8 +383,12 @@ class Bot(object): Plugin class to initialize and load. config : Optional The configuration to load the plugin with. + ctx : Optional[dict] + Context (previous state) to pass the plugin. Usually used along w/ + unload. """ if cls.__name__ in self.plugins: + self.log.warning('Attempted to add already added plugin %s', cls.__name__) raise Exception('Cannot add already added plugin: {}'.format(cls.__name__)) if not config: @@ -376,9 +397,10 @@ class Bot(object): else: config = self.load_plugin_config(cls) - self.plugins[cls.__name__] = cls(self, config) - self.plugins[cls.__name__].load() + self.ctx['plugin'] = self.plugins[cls.__name__] = cls(self, config) + self.plugins[cls.__name__].load(ctx or {}) self.recompute() + self.ctx.drop() def rmv_plugin(self, cls): """ @@ -392,9 +414,11 @@ class Bot(object): if cls.__name__ not in self.plugins: raise Exception('Cannot remove non-existant plugin: {}'.format(cls.__name__)) - self.plugins[cls.__name__].unload() + ctx = {} + self.plugins[cls.__name__].unload(ctx) del self.plugins[cls.__name__] self.recompute() + return ctx def reload_plugin(self, cls): """ @@ -402,13 +426,13 @@ class Bot(object): """ config = self.plugins[cls.__name__].config - self.rmv_plugin(cls) + ctx = self.rmv_plugin(cls) module = reload_module(inspect.getmodule(cls)) - self.add_plugin(getattr(module, cls.__name__), config) + self.add_plugin(getattr(module, cls.__name__), config, ctx) def run_forever(self): """ - Runs this bots core loop forever + Runs this bots core loop forever. """ self.client.run_forever() @@ -416,12 +440,14 @@ class Bot(object): """ Adds and loads a plugin, based on its module path. """ - + self.log.info('Adding plugin module at path "%s"', path) mod = importlib.import_module(path) loaded = False for entry in map(lambda i: getattr(mod, i), dir(mod)): if inspect.isclass(entry) and issubclass(entry, Plugin) and not entry == Plugin: + if getattr(entry, '_shallow', False) and Plugin in entry.__bases__: + continue loaded = True self.add_plugin(entry, config) @@ -430,23 +456,24 @@ class Bot(object): def load_plugin_config(self, cls): name = cls.__name__.lower() - if name.startswith('plugin'): - name = name[6:] + if name.endswith('plugin'): + name = name[:-6] path = os.path.join( self.config.plugin_config_dir, name) + '.' + self.config.plugin_config_format - if not os.path.exists(path): - if hasattr(cls, 'config_cls'): - return cls.config_cls() - return + data = {} + if name in self.config.plugin_config: + data = self.config.plugin_config[name] - with open(path, 'r') as f: - data = Serializer.loads(self.config.plugin_config_format, f.read()) + if os.path.exists(path): + with open(path, 'r') as f: + data.update(Serializer.loads(self.config.plugin_config_format, f.read())) if hasattr(cls, 'config_cls'): inst = cls.config_cls() - inst.update(data) + if data: + inst.update(data) return inst return data diff --git a/disco/bot/command.py b/disco/bot/command.py index 2546c9d..3d0398a 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -5,9 +5,11 @@ from holster.enum import Enum from disco.bot.parser import ArgumentSet, ArgumentError from disco.util.functional import cached_property -REGEX_FMT = '({})' -ARGS_REGEX = '( (.*)$|$)' -MENTION_RE = re.compile('<@!?([0-9]+)>') +ARGS_REGEX = '(?: ((?:\n|.)*)$|$)' + +USER_MENTION_RE = re.compile('<@!?([0-9]+)>') +ROLE_MENTION_RE = re.compile('<@&([0-9]+)>') +CHANNEL_MENTION_RE = re.compile('<#([0-9]+)>') CommandLevels = Enum( DEFAULT=0, @@ -42,34 +44,52 @@ class CommandEvent(object): self.command = command self.msg = msg self.match = match - self.name = self.match.group(1) - self.args = [i for i in self.match.group(2).strip().split(' ') if i] + self.name = self.match.group(0) + self.args = [] + + if self.match.group(1): + self.args = [i for i in self.match.group(1).strip().split(' ') if i] + + @property + def codeblock(self): + if '`' not in self.msg.content: + return ' '.join(self.args) + + _, src = self.msg.content.split('`', 1) + src = '`' + src + + if src.startswith('```') and src.endswith('```'): + src = src[3:-3] + elif src.startswith('`') and src.endswith('`'): + src = src[1:-1] + + return src @cached_property def member(self): """ - Guild member (if relevant) for the user that created the message + Guild member (if relevant) for the user that created the message. """ return self.guild.get_member(self.author) @property def channel(self): """ - Channel the message was created in + Channel the message was created in. """ return self.msg.channel @property def guild(self): """ - Guild (if relevant) the message was created in + Guild (if relevant) the message was created in. """ return self.msg.guild @property def author(self): """ - Author of the message + Author of the message. """ return self.msg.author @@ -107,61 +127,106 @@ class Command(object): self.plugin = plugin self.func = func self.triggers = [trigger] + + self.dispatch_func = None + self.raw_args = None + self.args = None + self.level = None + self.group = None + self.is_regex = None + self.oob = False + self.context = {} + self.metadata = {} + self.update(*args, **kwargs) - def update(self, args=None, level=None, aliases=None, group=None, is_regex=None): + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def get_docstring(self): + return (self.func.__doc__ or '').format(**self.context) + + def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False, context=None, **kwargs): self.triggers += aliases or [] - def resolve_role(ctx, id): - return ctx.msg.guild.roles.get(id) + def resolve_role(ctx, rid): + return ctx.msg.guild.roles.get(rid) + + def resolve_user(ctx, uid): + if isinstance(uid, int): + if uid in ctx.msg.mentions: + return ctx.msg.mentions.get(uid) + else: + return ctx.msg.client.state.users.get(uid) + else: + return ctx.msg.client.state.users.select_one(username=uid[0], discriminator=uid[1]) + + def resolve_channel(ctx, cid): + if isinstance(cid, (int, long)): + return ctx.msg.guild.channels.get(cid) + else: + return ctx.msg.guild.channels.select_one(name=cid) - def resolve_user(ctx, id): - return ctx.msg.mentions.get(id) + def resolve_guild(ctx, gid): + return ctx.msg.client.state.guilds.get(gid) + self.raw_args = args self.args = ArgumentSet.from_string(args or '', { - 'mention': self.mention_type([resolve_role, resolve_user]), - 'user': self.mention_type([resolve_user], force=True), - 'role': self.mention_type([resolve_role], force=True), + 'user': self.mention_type([resolve_user], USER_MENTION_RE, user=True), + 'role': self.mention_type([resolve_role], ROLE_MENTION_RE), + 'channel': self.mention_type([resolve_channel], CHANNEL_MENTION_RE, allow_plain=True), + 'guild': self.mention_type([resolve_guild]), }) self.level = level self.group = group self.is_regex = is_regex + self.oob = oob + self.context = context or {} + self.metadata = kwargs @staticmethod - def mention_type(getters, force=False): - def _f(ctx, i): - res = MENTION_RE.match(i) - if not res: - raise TypeError('Invalid mention: {}'.format(i)) - - id = int(res.group(1)) + def mention_type(getters, reg=None, user=False, allow_plain=False): + def _f(ctx, raw): + if raw.isdigit(): + resolved = int(raw) + elif user and raw.count('#') == 1 and raw.split('#')[-1].isdigit(): + username, discrim = raw.split('#') + resolved = (username, int(discrim)) + elif reg: + res = reg.match(raw) + if res: + resolved = int(res.group(1)) + else: + if allow_plain: + resolved = raw + else: + raise TypeError('Invalid mention: {}'.format(raw)) + else: + raise TypeError('Invalid mention: {}'.format(raw)) for getter in getters: - obj = getter(ctx, id) + obj = getter(ctx, resolved) if obj: return obj - if force: - raise TypeError('Cannot resolve mention: {}'.format(id)) - - return id + raise TypeError('Cannot resolve mention: {}'.format(raw)) return _f @cached_property def compiled_regex(self): """ - A compiled version of this command's regex + A compiled version of this command's regex. """ - return re.compile(self.regex) + return re.compile(self.regex, re.I) @property def regex(self): """ - The regex string that defines/triggers this command + The regex string that defines/triggers this command. """ if self.is_regex: - return REGEX_FMT.format('|'.join(self.triggers)) + return '|'.join(self.triggers) else: group = '' if self.group: @@ -169,7 +234,7 @@ class Command(object): group = '{}(?:\w+)? '.format(self.plugin.bot.group_abbrev.get(self.group)) else: group = self.group + ' ' - return REGEX_FMT.format('|'.join(['^' + group + trigger for trigger in self.triggers]) + ARGS_REGEX) + return '^{}(?:{})'.format(group, '|'.join(self.triggers)) + ARGS_REGEX def execute(self, event): """ @@ -189,8 +254,11 @@ class Command(object): )) try: - args = self.args.parse(event.args, ctx=event) + parsed_args = self.args.parse(event.args, ctx=event) except ArgumentError as e: raise CommandError(e.message) - return self.func(event, *args) + kwargs = {} + kwargs.update(self.context) + kwargs.update(parsed_args) + return self.plugin.dispatch('command', self, event, **kwargs) diff --git a/disco/bot/parser.py b/disco/bot/parser.py index 8f3483e..722abe6 100644 --- a/disco/bot/parser.py +++ b/disco/bot/parser.py @@ -2,9 +2,10 @@ import re import six import copy - # Regex which splits out argument parts -PARTS_RE = re.compile('(\<|\[)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\])') +PARTS_RE = re.compile('(\<|\[|\{)((?:\w+|\:|\||\.\.\.| (?:[0-9]+))+)(?:\>|\]|\})') + +BOOL_OPTS = {'yes': True, 'no': False, 'true': True, 'False': False, '1': True, '0': False} # Mapping of types TYPE_MAP = { @@ -14,6 +15,20 @@ TYPE_MAP = { 'snowflake': lambda ctx, data: int(data), } +try: + import dateparser + TYPE_MAP['duration'] = lambda ctx, data: dateparser.parse(data, settings={'TIMEZONE': 'UTC'}) +except ImportError: + pass + + +def to_bool(ctx, data): + if data in BOOL_OPTS: + return BOOL_OPTS[data] + raise TypeError + +TYPE_MAP['bool'] = to_bool + class ArgumentError(Exception): """ @@ -41,19 +56,20 @@ class Argument(object): self.name = None self.count = 1 self.required = False + self.flag = False self.types = None self.parse(raw) @property def true_count(self): """ - The true number of raw arguments this argument takes + The true number of raw arguments this argument takes. """ return self.count or 1 def parse(self, raw): """ - Attempts to parse arguments from their raw form + Attempts to parse arguments from their raw form. """ prefix, part = raw @@ -62,23 +78,27 @@ class Argument(object): else: self.required = False - if part.endswith('...'): - part = part[:-3] - self.count = 0 - elif ' ' in part: - split = part.split(' ', 1) - part, self.count = split[0], int(split[1]) + # Whether this is a flag + self.flag = (prefix == '{') + + if not self.flag: + if part.endswith('...'): + part = part[:-3] + self.count = 0 + elif ' ' in part: + split = part.split(' ', 1) + part, self.count = split[0], int(split[1]) - if ':' in part: - part, typeinfo = part.split(':') - self.types = typeinfo.split('|') + if ':' in part: + part, typeinfo = part.split(':') + self.types = typeinfo.split('|') self.name = part.strip() class ArgumentSet(object): """ - A set of :class:`Argument` instances which forms a larger argument specification + A set of :class:`Argument` instances which forms a larger argument specification. Attributes ---------- @@ -95,7 +115,7 @@ class ArgumentSet(object): @classmethod def from_string(cls, line, custom_types=None): """ - Creates a new :class:`ArgumentSet` from a given argument string specification + Creates a new :class:`ArgumentSet` from a given argument string specification. """ args = cls(custom_types=custom_types) @@ -131,7 +151,7 @@ class ArgumentSet(object): def append(self, arg): """ - Add a new :class:`Argument` to this argument specification/set + Add a new :class:`Argument` to this argument specification/set. """ if self.args and not self.args[-1].required and arg.required: raise Exception('Required argument cannot come after an optional argument') @@ -145,9 +165,23 @@ class ArgumentSet(object): """ Parse a string of raw arguments into this argument specification. """ - parsed = [] + parsed = {} + + flags = {i.name: i for i in self.args if i.flag} + if flags: + new_rawargs = [] + + for offset, raw in enumerate(rawargs): + if raw.startswith('-'): + raw = raw.lstrip('-') + if raw in flags: + parsed[raw] = True + continue + new_rawargs.append(raw) + + rawargs = new_rawargs - for index, arg in enumerate(self.args): + for index, arg in enumerate((arg for arg in self.args if not arg.flag)): if not arg.required and index + arg.true_count > len(rawargs): continue @@ -171,20 +205,20 @@ class ArgumentSet(object): if (not arg.types or arg.types == ['str']) and isinstance(raw, list): raw = ' '.join(raw) - parsed.append(raw) + parsed[arg.name] = raw return parsed @property def length(self): """ - The number of arguments in this set/specification + The number of arguments in this set/specification. """ return len(self.args) @property def required_length(self): """ - The number of required arguments to compile this set/specificaiton + The number of required arguments to compile this set/specificaiton. """ return sum([i.true_count for i in self.args if i.required]) diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 5ee6803..aeed5b3 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -1,9 +1,11 @@ import six +import types import gevent import inspect import weakref import functools +from gevent.event import AsyncResult from holster.emitter import Priority from disco.util.logging import LoggingClass @@ -18,8 +20,8 @@ class PluginDeco(object): Prio = Priority # TODO: dont smash class methods - @staticmethod - def add_meta_deco(meta): + @classmethod + def add_meta_deco(cls, meta): def deco(f): if not hasattr(f, 'meta'): f.meta = [] @@ -40,33 +42,33 @@ class PluginDeco(object): return deco @classmethod - def listen(cls, event_name, priority=None): + def listen(cls, *args, **kwargs): """ - Binds the function to listen for a given event name + Binds the function to listen for a given event name. """ return cls.add_meta_deco({ 'type': 'listener', 'what': 'event', - 'desc': event_name, - 'priority': priority + 'args': args, + 'kwargs': kwargs, }) @classmethod - def listen_packet(cls, op, priority=None): + def listen_packet(cls, *args, **kwargs): """ - Binds the function to listen for a given gateway op code + Binds the function to listen for a given gateway op code. """ return cls.add_meta_deco({ 'type': 'listener', 'what': 'packet', - 'desc': op, - 'priority': priority, + 'args': args, + 'kwargs': kwargs, }) @classmethod def command(cls, *args, **kwargs): """ - Creates a new command attached to the function + Creates a new command attached to the function. """ return cls.add_meta_deco({ 'type': 'command', @@ -77,7 +79,7 @@ class PluginDeco(object): @classmethod def pre_command(cls): """ - Runs a function before a command is triggered + Runs a function before a command is triggered. """ return cls.add_meta_deco({ 'type': 'pre_command', @@ -86,7 +88,7 @@ class PluginDeco(object): @classmethod def post_command(cls): """ - Runs a function after a command is triggered + Runs a function after a command is triggered. """ return cls.add_meta_deco({ 'type': 'post_command', @@ -95,7 +97,7 @@ class PluginDeco(object): @classmethod def pre_listener(cls): """ - Runs a function before a listener is triggered + Runs a function before a listener is triggered. """ return cls.add_meta_deco({ 'type': 'pre_listener', @@ -104,7 +106,7 @@ class PluginDeco(object): @classmethod def post_listener(cls): """ - Runs a function after a listener is triggered + Runs a function after a listener is triggered. """ return cls.add_meta_deco({ 'type': 'post_listener', @@ -113,7 +115,7 @@ class PluginDeco(object): @classmethod def schedule(cls, *args, **kwargs): """ - Runs a function repeatedly, waiting for a specified interval + Runs a function repeatedly, waiting for a specified interval. """ return cls.add_meta_deco({ 'type': 'schedule', @@ -153,46 +155,101 @@ class Plugin(LoggingClass, PluginDeco): self.storage = bot.storage self.config = config + # General declartions + self.listeners = [] + self.commands = [] + self.schedules = {} + self.greenlets = weakref.WeakSet() + self._pre = {} + self._post = {} + + # This is an array of all meta functions we sniff at init + self.meta_funcs = [] + + for name, member in inspect.getmembers(self, predicate=inspect.ismethod): + if hasattr(member, 'meta'): + self.meta_funcs.append(member) + + # Unsmash local functions + if hasattr(Plugin, name): + method = types.MethodType(getattr(Plugin, name), self, self.__class__) + setattr(self, name, method) + + self.bind_all() + @property def name(self): return self.__class__.__name__ def bind_all(self): self.listeners = [] - self.commands = {} + self.commands = [] self.schedules = {} self.greenlets = weakref.WeakSet() self._pre = {'command': [], 'listener': []} self._post = {'command': [], 'listener': []} - # TODO: when handling events/commands we need to track the greenlet in - # the greenlets set so we can termiante long running commands/listeners - # on reload. + for member in self.meta_funcs: + for meta in member.meta: + self.bind_meta(member, meta) + + def bind_meta(self, member, meta): + if meta['type'] == 'listener': + self.register_listener(member, meta['what'], *meta['args'], **meta['kwargs']) + elif meta['type'] == 'command': + # meta['kwargs']['update'] = True + self.register_command(member, *meta['args'], **meta['kwargs']) + elif meta['type'] == 'schedule': + self.register_schedule(member, *meta['args'], **meta['kwargs']) + elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): + when, typ = meta['type'].split('_', 1) + self.register_trigger(typ, when, member) + + def handle_exception(self, greenlet, event): + pass + + def wait_for_event(self, event_name, **kwargs): + result = AsyncResult() + listener = None + + def _event_callback(event): + for k, v in kwargs.items(): + if getattr(event, k) != v: + break + else: + listener.remove() + return result.set(event) - for name, member in inspect.getmembers(self, predicate=inspect.ismethod): - if hasattr(member, 'meta'): - for meta in member.meta: - if meta['type'] == 'listener': - self.register_listener(member, meta['what'], meta['desc'], meta['priority']) - elif meta['type'] == 'command': - meta['kwargs']['update'] = True - self.register_command(member, *meta['args'], **meta['kwargs']) - elif meta['type'] == 'schedule': - self.register_schedule(member, *meta['args'], **meta['kwargs']) - elif meta['type'].startswith('pre_') or meta['type'].startswith('post_'): - when, typ = meta['type'].split('_', 1) - self.register_trigger(typ, when, member) - - def spawn(self, method, *args, **kwargs): - obj = gevent.spawn(method, *args, **kwargs) + listener = self.bot.client.events.on(event_name, _event_callback) + + return result + + def spawn_wrap(self, spawner, method, *args, **kwargs): + def wrapped(*args, **kwargs): + self.ctx['plugin'] = self + try: + res = method(*args, **kwargs) + return res + finally: + self.ctx.drop() + + obj = spawner(wrapped, *args, **kwargs) self.greenlets.add(obj) return obj + def spawn(self, *args, **kwargs): + return self.spawn_wrap(gevent.spawn, *args, **kwargs) + + def spawn_later(self, delay, *args, **kwargs): + return self.spawn_wrap(functools.partial(gevent.spawn_later, delay), *args, **kwargs) + def execute(self, event): """ - Executes a CommandEvent this plugin owns + Executes a CommandEvent this plugin owns. """ + if not event.command.oob: + self.greenlets.add(gevent.getcurrent()) try: return event.command.execute(event) except CommandError as e: @@ -203,11 +260,18 @@ class Plugin(LoggingClass, PluginDeco): def register_trigger(self, typ, when, func): """ - Registers a trigger + Registers a trigger. """ getattr(self, '_' + when)[typ].append(func) - def _dispatch(self, typ, func, event, *args, **kwargs): + def dispatch(self, typ, func, event, *args, **kwargs): + # Link the greenlet with our exception handler + gevent.getcurrent().link_exception(lambda g: self.handle_exception(g, event)) + + # TODO: this is ugly + if typ != 'command': + self.greenlets.add(gevent.getcurrent()) + self.ctx['plugin'] = self if hasattr(event, 'guild'): @@ -218,7 +282,7 @@ class Plugin(LoggingClass, PluginDeco): self.ctx['user'] = event.author for pre in self._pre[typ]: - event = pre(event, args, kwargs) + event = pre(func, event, args, kwargs) if event is None: return False @@ -226,13 +290,13 @@ class Plugin(LoggingClass, PluginDeco): result = func(event, *args, **kwargs) for post in self._post[typ]: - post(event, args, kwargs, result) + post(func, event, args, kwargs, result) return True - def register_listener(self, func, what, desc, priority): + def register_listener(self, func, what, *args, **kwargs): """ - Registers a listener + Registers a listener. Parameters ---------- @@ -242,17 +306,13 @@ class Plugin(LoggingClass, PluginDeco): The function to be registered. desc The descriptor of the event/packet. - priority : Priority - The priority of this listener. """ - func = functools.partial(self._dispatch, 'listener', func) - - priority = priority or Priority.NONE + args = list(args) + [functools.partial(self.dispatch, 'listener', func)] if what == 'event': - li = self.bot.client.events.on(desc, func, priority=priority) + li = self.bot.client.events.on(*args, **kwargs) elif what == 'packet': - li = self.bot.client.packets.on(desc, func, priority=priority) + li = self.bot.client.packets.on(*args, **kwargs) else: raise Exception('Invalid listener what: {}'.format(what)) @@ -260,7 +320,7 @@ class Plugin(LoggingClass, PluginDeco): def register_command(self, func, *args, **kwargs): """ - Registers a command + Registers a command. Parameters ---------- @@ -272,11 +332,7 @@ class Plugin(LoggingClass, PluginDeco): Keyword arguments to pass onto the :class:`disco.bot.command.Command` object. """ - if kwargs.pop('update', False) and func.__name__ in self.commands: - self.commands[func.__name__].update(*args, **kwargs) - else: - wrapped = functools.partial(self._dispatch, 'command', func) - self.commands[func.__name__] = Command(self, wrapped, *args, **kwargs) + self.commands.append(Command(self, func, *args, **kwargs)) def register_schedule(self, func, interval, repeat=True, init=True): """ @@ -289,8 +345,13 @@ class Plugin(LoggingClass, PluginDeco): The function to be registered. interval : int Interval (in seconds) to repeat the function on. + repeat : bool + Whether this schedule is repeating (or one time). + init : bool + Whether to run this schedule once immediatly, or wait for the first + scheduled iteration. """ - def repeat(): + def repeat_func(): if init: func() @@ -300,17 +361,17 @@ class Plugin(LoggingClass, PluginDeco): if not repeat: break - self.schedules[func.__name__] = self.spawn(repeat) + self.schedules[func.__name__] = self.spawn(repeat_func) - def load(self): + def load(self, ctx): """ - Called when the plugin is loaded + Called when the plugin is loaded. """ - self.bind_all() + pass - def unload(self): + def unload(self, ctx): """ - Called when the plugin is unloaded + Called when the plugin is unloaded. """ for greenlet in self.greenlets: greenlet.kill() diff --git a/disco/bot/providers/disk.py b/disco/bot/providers/disk.py index 5cf1ca3..af259e1 100644 --- a/disco/bot/providers/disk.py +++ b/disco/bot/providers/disk.py @@ -13,6 +13,7 @@ class DiskProvider(BaseProvider): self.fsync = config.get('fsync', False) self.fsync_changes = config.get('fsync_changes', 1) + self.autosave_task = None self.change_count = 0 def autosave_loop(self, interval): diff --git a/disco/bot/providers/redis.py b/disco/bot/providers/redis.py index d0c2e5b..f5e1375 100644 --- a/disco/bot/providers/redis.py +++ b/disco/bot/providers/redis.py @@ -10,32 +10,39 @@ from .base import BaseProvider, SEP_SENTINEL class RedisProvider(BaseProvider): def __init__(self, config): - self.config = config + super(RedisProvider, self).__init__(config) + self.format = config.get('format', 'pickle') + self.conn = None def load(self): - self.redis = redis.Redis( + self.conn = redis.Redis( host=self.config.get('host', 'localhost'), port=self.config.get('port', 6379), db=self.config.get('db', 0)) def exists(self, key): - return self.db.exists(key) + return self.conn.exists(key) def keys(self, other): count = other.count(SEP_SENTINEL) + 1 - for key in self.db.scan_iter(u'{}*'.format(other)): + for key in self.conn.scan_iter(u'{}*'.format(other)): + key = key.decode('utf-8') if key.count(SEP_SENTINEL) == count: yield key def get_many(self, keys): - for key, value in izip(keys, self.db.mget(keys)): + keys = list(keys) + if not len(keys): + raise StopIteration + + for key, value in izip(keys, self.conn.mget(keys)): yield (key, Serializer.loads(self.format, value)) def get(self, key): - return Serializer.loads(self.format, self.db.get(key)) + return Serializer.loads(self.format, self.conn.get(key)) def set(self, key, value): - self.db.set(key, Serializer.dumps(self.format, value)) + self.conn.set(key, Serializer.dumps(self.format, value)) - def delete(self, key, value): - self.db.delete(key) + def delete(self, key): + self.conn.delete(key) diff --git a/disco/bot/providers/rocksdb.py b/disco/bot/providers/rocksdb.py index 0062d79..986268d 100644 --- a/disco/bot/providers/rocksdb.py +++ b/disco/bot/providers/rocksdb.py @@ -12,11 +12,13 @@ from .base import BaseProvider, SEP_SENTINEL class RocksDBProvider(BaseProvider): def __init__(self, config): - self.config = config + super(RocksDBProvider, self).__init__(config) self.format = config.get('format', 'pickle') self.path = config.get('path', 'storage.db') + self.db = None - def k(self, k): + @staticmethod + def k(k): return bytes(k) if six.PY3 else str(k.encode('utf-8')) def load(self): diff --git a/disco/cli.py b/disco/cli.py index 951fd96..10aa9f3 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -18,14 +18,13 @@ parser.add_argument('--config', help='Configuration file', default='config.yaml' parser.add_argument('--token', help='Bot Authentication Token', default=None) parser.add_argument('--shard-count', help='Total number of shards', default=None) parser.add_argument('--shard-id', help='Current shard number/id', default=None) +parser.add_argument('--shard-auto', help='Automatically run all shards', action='store_true', default=False) parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=None) parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default=None) parser.add_argument('--encoder', help='encoder for gateway data', default=None) parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False) parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[]) -logging.basicConfig(level=logging.INFO) - def disco_main(run=False): """ @@ -42,6 +41,7 @@ def disco_main(run=False): from disco.client import Client, ClientConfig from disco.bot import Bot, BotConfig from disco.util.token import is_valid_token + from disco.util.logging import setup_logging if os.path.exists(args.config): config = ClientConfig.from_file(args.config) @@ -56,12 +56,23 @@ def disco_main(run=False): print('Invalid token passed') return + if args.shard_auto: + from disco.gateway.sharder import AutoSharder + AutoSharder(config).run() + return + + # TODO: make configurable + setup_logging(level=logging.INFO) + client = Client(config) bot = None if args.run_bot or hasattr(config, 'bot'): bot_config = BotConfig(config.bot) if hasattr(config, 'bot') else BotConfig() - bot_config.plugins += args.plugin + if not hasattr(bot_config, 'plugins'): + bot_config.plugins = args.plugin + else: + bot_config.plugins += args.plugin bot = Bot(client, bot_config) if run: diff --git a/disco/client.py b/disco/client.py index 54dacd5..a5791ee 100644 --- a/disco/client.py +++ b/disco/client.py @@ -1,3 +1,4 @@ +import time import gevent from holster.emitter import Emitter @@ -5,24 +6,28 @@ from holster.emitter import Emitter from disco.state import State, StateConfig from disco.api.client import APIClient from disco.gateway.client import GatewayClient +from disco.gateway.packets import OPCode +from disco.types.user import Status, Game from disco.util.config import Config from disco.util.logging import LoggingClass from disco.util.backdoor import DiscoBackdoorServer -class ClientConfig(LoggingClass, Config): +class ClientConfig(Config): """ Configuration for the :class:`Client`. Attributes ---------- token : str - Discord authentication token, ca be validated using the + Discord authentication token, can be validated using the :func:`disco.util.token.is_valid_token` function. shard_id : int The shard ID for the current client instance. shard_count : int The total count of shards running. + max_reconnects : int + The maximum number of connection retries to make before giving up (0 = never give up). manhole_enable : bool Whether to enable the manhole (e.g. console backdoor server) utility. manhole_bind : tuple(str, int) @@ -36,14 +41,15 @@ class ClientConfig(LoggingClass, Config): token = "" shard_id = 0 shard_count = 1 + max_reconnects = 5 - manhole_enable = True + manhole_enable = False manhole_bind = ('127.0.0.1', 8484) encoder = 'json' -class Client(object): +class Client(LoggingClass): """ Class representing the base entry point that should be used in almost all implementation cases. This class wraps the functionality of both the REST API @@ -82,8 +88,8 @@ class Client(object): self.events = Emitter(gevent.spawn) self.packets = Emitter(gevent.spawn) - self.api = APIClient(self) - self.gw = GatewayClient(self, self.config.encoder) + self.api = APIClient(self.config.token, self) + self.gw = GatewayClient(self, self.config.max_reconnects, self.config.encoder) self.state = State(self, StateConfig(self.config.get('state', {}))) if self.config.manhole_enable: @@ -95,18 +101,37 @@ class Client(object): } self.manhole = DiscoBackdoorServer(self.config.manhole_bind, - banner='Disco Manhole', - localf=lambda: self.manhole_locals) + banner='Disco Manhole', + localf=lambda: self.manhole_locals) self.manhole.start() + def update_presence(self, game=None, status=None, afk=False, since=0.0): + if game and not isinstance(game, Game): + raise TypeError('Game must be a Game model') + + if status is Status.IDLE and not since: + since = int(time.time() * 1000) + + payload = { + 'afk': afk, + 'since': since, + 'status': status.value.lower(), + 'game': None, + } + + if game: + payload['game'] = game.to_dict() + + self.gw.send(OPCode.STATUS_UPDATE, payload) + def run(self): """ - Run the client (e.g. the :class:`GatewayClient`) in a new greenlet + Run the client (e.g. the :class:`GatewayClient`) in a new greenlet. """ return gevent.spawn(self.gw.run) def run_forever(self): """ - Run the client (e.g. the :class:`GatewayClient`) in the current greenlet + Run the client (e.g. the :class:`GatewayClient`) in the current greenlet. """ return self.gw.run() diff --git a/disco/gateway/client.py b/disco/gateway/client.py index b3f8012..7386bb5 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -15,16 +15,21 @@ TEN_MEGABYTES = 10490000 class GatewayClient(LoggingClass): GATEWAY_VERSION = 6 - MAX_RECONNECTS = 5 - def __init__(self, client, encoder='json'): + def __init__(self, client, max_reconnects=5, encoder='json', ipc=None): super(GatewayClient, self).__init__() self.client = client + self.max_reconnects = max_reconnects self.encoder = ENCODERS[encoder] self.events = client.events self.packets = client.packets + # IPC for shards + if ipc: + self.shards = ipc.get_shards() + self.ipc = ipc + # Its actually 60, 120 but lets give ourselves a buffer self.limiter = SimpleLimiter(60, 130) @@ -37,6 +42,7 @@ class GatewayClient(LoggingClass): # Bind to ready payload self.events.on('Ready', self.on_ready) + self.events.on('Resumed', self.on_resumed) # Websocket connection self.ws = None @@ -76,15 +82,15 @@ class GatewayClient(LoggingClass): self.log.debug('Dispatching %s', obj.__class__.__name__) self.client.events.emit(obj.__class__.__name__, obj) - def handle_heartbeat(self, packet): + def handle_heartbeat(self, _): self._send(OPCode.HEARTBEAT, self.seq) - def handle_reconnect(self, packet): + def handle_reconnect(self, _): self.log.warning('Received RECONNECT request, forcing a fresh reconnect') self.session_id = None self.ws.close() - def handle_invalid_session(self, packet): + def handle_invalid_session(self, _): self.log.warning('Recieved INVALID_SESSION, forcing a fresh reconnect') self.session_id = None self.ws.close() @@ -98,14 +104,21 @@ class GatewayClient(LoggingClass): self.session_id = ready.session_id self.reconnects = 0 - def connect_and_run(self): - if not self._cached_gateway_url: - self._cached_gateway_url = self.client.api.gateway( - version=self.GATEWAY_VERSION, - encoding=self.encoder.TYPE) + def on_resumed(self, _): + self.log.info('Recieved RESUMED') + self.reconnects = 0 + + def connect_and_run(self, gateway_url=None): + if not gateway_url: + if not self._cached_gateway_url: + self._cached_gateway_url = self.client.api.gateway_get()['url'] - self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) - self.ws = Websocket(self._cached_gateway_url) + gateway_url = self._cached_gateway_url + + gateway_url += '?v={}&encoding={}'.format(self.GATEWAY_VERSION, self.encoder.TYPE) + + self.log.info('Opening websocket connection to URL `%s`', gateway_url) + self.ws = Websocket(gateway_url) self.ws.emitter.on('on_open', self.on_open) self.ws.emitter.on('on_error', self.on_error) self.ws.emitter.on('on_close', self.on_close) @@ -153,8 +166,8 @@ class GatewayClient(LoggingClass): 'compress': True, 'large_threshold': 250, 'shard': [ - self.client.config.shard_id, - self.client.config.shard_count, + int(self.client.config.shard_id), + int(self.client.config.shard_count), ], 'properties': { '$os': 'linux', @@ -165,15 +178,22 @@ class GatewayClient(LoggingClass): }) def on_close(self, code, reason): + # Kill heartbeater, a reconnect/resume will trigger a HELLO which will + # respawn it + if self._heartbeat_task: + self._heartbeat_task.kill() + + # If we're quitting, just break out of here if self.shutting_down: self.log.info('WS Closed: shutting down') return + # Track reconnect attempts self.reconnects += 1 self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) - if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS: - raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) + if self.max_reconnects and self.reconnects > self.max_reconnects: + raise Exception('Failed to reconnect after {} attempts, giving up'.format(self.max_reconnects)) # Don't resume for these error codes if code and 4000 <= code <= 4010: diff --git a/disco/gateway/encoding/base.py b/disco/gateway/encoding/base.py index e663cf6..f4903d9 100644 --- a/disco/gateway/encoding/base.py +++ b/disco/gateway/encoding/base.py @@ -1,7 +1,9 @@ from websocket import ABNF +from holster.interface import Interface -class BaseEncoder(object): + +class BaseEncoder(Interface): TYPE = None OPCODE = ABNF.OPCODE_TEXT diff --git a/disco/gateway/encoding/json.py b/disco/gateway/encoding/json.py index 8810198..8550ac5 100644 --- a/disco/gateway/encoding/json.py +++ b/disco/gateway/encoding/json.py @@ -1,7 +1,5 @@ from __future__ import absolute_import, print_function -import six - try: import ujson as json except ImportError: diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6d55690..9d9d323 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -4,20 +4,20 @@ import inflection import six from disco.types.user import User, Presence -from disco.types.channel import Channel -from disco.types.message import Message +from disco.types.channel import Channel, PermissionOverwrite +from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState -from disco.types.guild import Guild, GuildMember, Role +from disco.types.guild import Guild, GuildMember, Role, GuildEmoji -from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime +from disco.types.base import Model, ModelMeta, Field, ListField, AutoDictField, snowflake, datetime # Mapping of discords event name to our event classes EVENTS_MAP = {} class GatewayEventMeta(ModelMeta): - def __new__(cls, name, parents, dct): - obj = super(GatewayEventMeta, cls).__new__(cls, name, parents, dct) + def __new__(mcs, name, parents, dct): + obj = super(GatewayEventMeta, mcs).__new__(mcs, name, parents, dct) if name != 'GatewayEvent': EVENTS_MAP[inflection.underscore(name).upper()] = obj @@ -64,22 +64,21 @@ class GatewayEvent(six.with_metaclass(GatewayEventMeta, Model)): return cls(obj, client) def __getattr__(self, name): - if hasattr(self, '_wraps_model'): - modname, _ = self._wraps_model - if hasattr(self, modname) and hasattr(getattr(self, modname), name): - return getattr(getattr(self, modname), name) - raise AttributeError(name) + if hasattr(self, '_proxy'): + return getattr(getattr(self, self._proxy), name) + return object.__getattribute__(self, name) -def debug(func=None): +def debug(func=None, match=None): def deco(cls): old_init = cls.__init__ def new_init(self, obj, *args, **kwargs): - if func: - print(func(obj)) - else: - print(obj) + if not match or match(obj): + if func: + print(func(obj)) + else: + print(obj) old_init(self, obj, *args, **kwargs) @@ -93,8 +92,16 @@ def wraps_model(model, alias=None): def deco(cls): cls._fields[alias] = Field(model) - cls._fields[alias].set_name(alias) + cls._fields[alias].name = alias cls._wraps_model = (alias, model) + cls._proxy = alias + return cls + return deco + + +def proxy(field): + def deco(cls): + cls._proxy = field return cls return deco @@ -103,49 +110,102 @@ class Ready(GatewayEvent): """ Sent after the initial gateway handshake is complete. Contains data required for bootstrapping the client's states. + + Attributes + ----- + version : int + The gateway version. + session_id : str + The session ID. + user : :class:`disco.types.user.User` + The user object for the authed account. + guilds : list[:class:`disco.types.guild.Guild` + All guilds this account is a member of. These are shallow guild objects. + private_channels list[:class:`disco.types.channel.Channel`] + All private channels (DMs) open for this account. """ version = Field(int, alias='v') session_id = Field(str) user = Field(User) - guilds = Field(listof(Guild)) - private_channels = Field(listof(Channel)) + guilds = ListField(Guild) + private_channels = ListField(Channel) + trace = ListField(str, alias='_trace') class Resumed(GatewayEvent): """ Sent after a resume completes. """ - pass + trace = ListField(str, alias='_trace') @wraps_model(Guild) class GuildCreate(GatewayEvent): """ - Sent when a guild is created, or becomes available. + Sent when a guild is joined, or becomes available. + + Attributes + ----- + guild : :class:`disco.types.guild.Guild` + The guild being created (e.g. joined) + unavailable : bool + If false, this guild is coming online from a previously unavailable state, + and if None, this is a normal guild join event. """ unavailable = Field(bool) + @property + def created(self): + """ + Shortcut property which is true when we actually joined the guild. + """ + return self.unavailable is None + @wraps_model(Guild) class GuildUpdate(GatewayEvent): """ Sent when a guild is updated. + + Attributes + ----- + guild : :class:`disco.types.guild.Guild` + The updated guild object. """ - pass class GuildDelete(GatewayEvent): """ - Sent when a guild is deleted, or becomes unavailable. + Sent when a guild is deleted, left, or becomes unavailable. + + Attributes + ----- + id : snowflake + The ID of the guild being deleted. + unavailable : bool + If true, this guild is becoming unavailable, if None this is a normal + guild leave event. """ id = Field(snowflake) unavailable = Field(bool) + @property + def deleted(self): + """ + Shortcut property which is true when we actually have left the guild. + """ + return self.unavailable is None + @wraps_model(Channel) class ChannelCreate(GatewayEvent): """ Sent when a channel is created. + + Attributes + ----- + channel : :class:`disco.types.channel.Channel` + The channel which was created. """ @@ -153,115 +213,236 @@ class ChannelCreate(GatewayEvent): class ChannelUpdate(ChannelCreate): """ Sent when a channel is updated. + + Attributes + ----- + channel : :class:`disco.types.channel.Channel` + The channel which was updated. """ - pass + overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites') @wraps_model(Channel) class ChannelDelete(ChannelCreate): """ Sent when a channel is deleted. + + Attributes + ----- + channel : :class:`disco.types.channel.Channel` + The channel being deleted. """ - pass class ChannelPinsUpdate(GatewayEvent): """ Sent when a channel's pins are updated. + + Attributes + ----- + channel_id : snowflake + ID of the channel where pins where updated. + last_pin_timestap : datetime + The time the last message was pinned. """ channel_id = Field(snowflake) - last_pin_timestamp = Field(lazy_datetime) + last_pin_timestamp = Field(datetime) -@wraps_model(User) +@proxy(User) class GuildBanAdd(GatewayEvent): """ Sent when a user is banned from a guild. + + Attributes + ----- + guild_id : snowflake + The ID of the guild the user is being banned from. + user : :class:`disco.types.user.User` + The user being banned from the guild. """ - pass + guild_id = Field(snowflake) + user = Field(User) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) -@wraps_model(User) + +@proxy(User) class GuildBanRemove(GuildBanAdd): """ Sent when a user is unbanned from a guild. + + Attributes + ----- + guild_id : snowflake + The ID of the guild the user is being unbanned from. + user : :class:`disco.types.user.User` + The user being unbanned from the guild. """ - pass + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) class GuildEmojisUpdate(GatewayEvent): """ Sent when a guild's emojis are updated. + + Attributes + ----- + guild_id : snowflake + The ID of the guild the emojis are being updated in. + emojis : list[:class:`disco.types.guild.Emoji`] + The new set of emojis for the guild """ - pass + guild_id = Field(snowflake) + emojis = ListField(GuildEmoji) class GuildIntegrationsUpdate(GatewayEvent): """ Sent when a guild's integrations are updated. + + Attributes + ----- + guild_id : snowflake + The ID of the guild integrations where updated in. """ - pass + guild_id = Field(snowflake) class GuildMembersChunk(GatewayEvent): """ Sent in response to a member's chunk request. + + Attributes + ----- + guild_id : snowflake + The ID of the guild this member chunk is for. + members : list[:class:`disco.types.guild.GuildMember`] + The chunk of members. """ guild_id = Field(snowflake) - members = Field(listof(GuildMember)) + members = ListField(GuildMember) + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) @wraps_model(GuildMember, alias='member') class GuildMemberAdd(GatewayEvent): """ Sent when a user joins a guild. + + Attributes + ----- + member : :class:`disco.types.guild.GuildMember` + The member that has joined the guild. """ - pass +@proxy('user') class GuildMemberRemove(GatewayEvent): """ Sent when a user leaves a guild (via leaving, kicking, or banning). + + Attributes + ----- + guild_id : snowflake + The ID of the guild the member left from. + user : :class:`disco.types.user.User` + The user who was removed from the guild. """ - guild_id = Field(snowflake) user = Field(User) + guild_id = Field(snowflake) + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) @wraps_model(GuildMember, alias='member') class GuildMemberUpdate(GatewayEvent): """ Sent when a guilds member is updated. + + Attributes + ----- + member : :class:`disco.types.guild.GuildMember` + The member being updated """ - pass +@proxy('role') class GuildRoleCreate(GatewayEvent): """ Sent when a role is created. + + Attributes + ----- + guild_id : snowflake + The ID of the guild where the role was created. + role : :class:`disco.types.guild.Role` + The role that was created. """ - guild_id = Field(snowflake) role = Field(Role) + guild_id = Field(snowflake) + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) +@proxy('role') class GuildRoleUpdate(GuildRoleCreate): """ Sent when a role is updated. + + Attributes + ----- + guild_id : snowflake + The ID of the guild where the role was created. + role : :class:`disco.types.guild.Role` + The role that was created. """ - pass + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) class GuildRoleDelete(GatewayEvent): """ Sent when a role is deleted. + + Attributes + ----- + guild_id : snowflake + The ID of the guild where the role is being deleted. + role_id : snowflake + The id of the role being deleted. """ guild_id = Field(snowflake) role_id = Field(snowflake) + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + @wraps_model(Message) class MessageCreate(GatewayEvent): """ Sent when a message is created. + + Attributes + ----- + message : :class:`disco.types.message.Message` + The message being created. """ @@ -269,55 +450,124 @@ class MessageCreate(GatewayEvent): class MessageUpdate(MessageCreate): """ Sent when a message is updated/edited. + + Attributes + ----- + message : :class:`disco.types.message.Message` + The message being updated. """ - pass class MessageDelete(GatewayEvent): """ Sent when a message is deleted. + + Attributes + ----- + id : snowflake + The ID of message being deleted. + channel_id : snowflake + The ID of the channel the message was deleted in. """ id = Field(snowflake) channel_id = Field(snowflake) + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild + class MessageDeleteBulk(GatewayEvent): """ Sent when multiple messages are deleted from a channel. + + Attributes + ----- + channel_id : snowflake + The channel the messages are being deleted in. + ids : list[snowflake] + List of messages being deleted in the channel. """ channel_id = Field(snowflake) - ids = Field(listof(snowflake)) + ids = ListField(snowflake) + + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild @wraps_model(Presence) class PresenceUpdate(GatewayEvent): """ Sent when a user's presence is updated. + + Attributes + ----- + presence : :class:`disco.types.user.Presence` + The updated presence object. + guild_id : snowflake + The guild this presence update is for. + roles : list[snowflake] + List of roles the user from the presence is part of. """ guild_id = Field(snowflake) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) + + @property + def guild(self): + return self.client.state.guilds.get(self.guild_id) class TypingStart(GatewayEvent): """ Sent when a user begins typing in a channel. + + Attributes + ----- + channel_id : snowflake + The ID of the channel where the user is typing. + user_id : snowflake + The ID of the user who is typing. + timestamp : datetime + When the user started typing. """ channel_id = Field(snowflake) user_id = Field(snowflake) - timestamp = Field(snowflake) + timestamp = Field(datetime) @wraps_model(VoiceState, alias='state') class VoiceStateUpdate(GatewayEvent): """ Sent when a users voice state changes. + + Attributes + ----- + state : :class:`disco.models.voice.VoiceState` + The voice state which was updated. """ - pass class VoiceServerUpdate(GatewayEvent): """ Sent when a voice server is updated. + + Attributes + ----- + token : str + The token for the voice server. + endpoint : str + The endpoint for the voice server. + guild_id : snowflake + The guild ID this voice server update is for. """ token = Field(str) endpoint = Field(str) @@ -327,6 +577,94 @@ class VoiceServerUpdate(GatewayEvent): class WebhooksUpdate(GatewayEvent): """ Sent when a channels webhooks are updated. + + Attributes + ----- + channel_id : snowflake + The channel ID this webhooks update is for. + guild_id : snowflake + The guild ID this webhooks update is for. """ channel_id = Field(snowflake) guild_id = Field(snowflake) + + +class MessageReactionAdd(GatewayEvent): + """ + Sent when a reaction is added to a message. + + Attributes + ---------- + channel_id : snowflake + The channel ID the message is in. + messsage_id : snowflake + The ID of the message for which the reaction was added too. + user_id : snowflake + The ID of the user who added the reaction. + emoji : :class:`disco.types.message.MessageReactionEmoji` + The emoji which was added. + """ + channel_id = Field(snowflake) + message_id = Field(snowflake) + user_id = Field(snowflake) + emoji = Field(MessageReactionEmoji) + + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild + + +class MessageReactionRemove(GatewayEvent): + """ + Sent when a reaction is removed from a message. + + Attributes + ---------- + channel_id : snowflake + The channel ID the message is in. + messsage_id : snowflake + The ID of the message for which the reaction was removed from. + user_id : snowflake + The ID of the user who originally added the reaction. + emoji : :class:`disco.types.message.MessageReactionEmoji` + The emoji which was removed. + """ + channel_id = Field(snowflake) + message_id = Field(snowflake) + user_id = Field(snowflake) + emoji = Field(MessageReactionEmoji) + + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild + + +class MessageReactionRemoveAll(GatewayEvent): + """ + Sent when all reactions are removed from a message. + + Attributes + ---------- + channel_id : snowflake + The channel ID the message is in. + message_id : snowflake + The ID of the message for which the reactions where removed from. + """ + channel_id = Field(snowflake) + message_id = Field(snowflake) + + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild diff --git a/disco/gateway/ipc.py b/disco/gateway/ipc.py new file mode 100644 index 0000000..bcd3383 --- /dev/null +++ b/disco/gateway/ipc.py @@ -0,0 +1,91 @@ +import random +import gevent +import string +import weakref + +from holster.enum import Enum + +from disco.util.logging import LoggingClass +from disco.util.serializer import dump_function, load_function + + +def get_random_str(size): + return ''.join([random.choice(string.printable) for _ in range(size)]) + + +IPCMessageType = Enum( + 'CALL_FUNC', + 'GET_ATTR', + 'EXECUTE', + 'RESPONSE', +) + + +class GIPCProxy(LoggingClass): + def __init__(self, obj, pipe): + super(GIPCProxy, self).__init__() + self.obj = obj + self.pipe = pipe + self.results = weakref.WeakValueDictionary() + gevent.spawn(self.read_loop) + + def resolve(self, parts): + base = self.obj + for part in parts: + base = getattr(base, part) + + return base + + def send(self, typ, data): + self.pipe.put((typ.value, data)) + + def handle(self, mtype, data): + if mtype == IPCMessageType.CALL_FUNC: + nonce, func, args, kwargs = data + res = self.resolve(func)(*args, **kwargs) + self.send(IPCMessageType.RESPONSE, (nonce, res)) + elif mtype == IPCMessageType.GET_ATTR: + nonce, path = data + self.send(IPCMessageType.RESPONSE, (nonce, self.resolve(path))) + elif mtype == IPCMessageType.EXECUTE: + nonce, raw = data + func = load_function(raw) + try: + result = func(self.obj) + except Exception: + self.log.exception('Failed to EXECUTE: ') + result = None + + self.send(IPCMessageType.RESPONSE, (nonce, result)) + elif mtype == IPCMessageType.RESPONSE: + nonce, res = data + if nonce in self.results: + self.results[nonce].set(res) + + def read_loop(self): + while True: + mtype, data = self.pipe.get() + + try: + self.handle(mtype, data) + except: + self.log.exception('Error in GIPCProxy:') + + def execute(self, func): + nonce = get_random_str(32) + raw = dump_function(func) + self.results[nonce] = result = gevent.event.AsyncResult() + self.pipe.put((IPCMessageType.EXECUTE.value, (nonce, raw))) + return result + + def get(self, path): + nonce = get_random_str(32) + self.results[nonce] = result = gevent.event.AsyncResult() + self.pipe.put((IPCMessageType.GET_ATTR.value, (nonce, path))) + return result + + def call(self, path, *args, **kwargs): + nonce = get_random_str(32) + self.results[nonce] = result = gevent.event.AsyncResult() + self.pipe.put((IPCMessageType.CALL_FUNC.value, (nonce, path, args, kwargs))) + return result diff --git a/disco/gateway/packets.py b/disco/gateway/packets.py index e78c1ce..a15bfd8 100644 --- a/disco/gateway/packets.py +++ b/disco/gateway/packets.py @@ -1,7 +1,7 @@ from holster.enum import Enum -SEND = object() -RECV = object() +SEND = 1 +RECV = 2 OPCode = Enum( DISPATCH=0, diff --git a/disco/gateway/sharder.py b/disco/gateway/sharder.py new file mode 100644 index 0000000..99e8d1a --- /dev/null +++ b/disco/gateway/sharder.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import + +import gipc +import gevent +import pickle +import logging +import marshal + +from six.moves import range + +from disco.client import Client +from disco.bot import Bot, BotConfig +from disco.api.client import APIClient +from disco.gateway.ipc import GIPCProxy +from disco.util.logging import setup_logging +from disco.util.snowflake import calculate_shard +from disco.util.serializer import dump_function, load_function + + +def run_shard(config, shard_id, pipe): + setup_logging( + level=logging.INFO, + format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(shard_id) + ) + + config.shard_id = shard_id + client = Client(config) + bot = Bot(client, BotConfig(config.bot)) + bot.sharder = GIPCProxy(bot, pipe) + bot.shards = ShardHelper(config.shard_count, bot) + bot.run_forever() + + +class ShardHelper(object): + def __init__(self, count, bot): + self.count = count + self.bot = bot + + def keys(self): + for sid in range(self.count): + yield sid + + def on(self, id, func): + if id == self.bot.client.config.shard_id: + result = gevent.event.AsyncResult() + result.set(func(self.bot)) + return result + + return self.bot.sharder.call(('run_on', ), id, dump_function(func)) + + def all(self, func, timeout=None): + pool = gevent.pool.Pool(self.count) + return dict(zip(range(self.count), pool.imap(lambda i: self.on(i, func).wait(timeout=timeout), range(self.count)))) + + def for_id(self, sid, func): + shard = calculate_shard(self.count, sid) + return self.on(shard, func) + + +class AutoSharder(object): + def __init__(self, config): + self.config = config + self.client = APIClient(config.token) + self.shards = {} + self.config.shard_count = self.client.gateway_bot_get()['shards'] + + def run_on(self, sid, raw): + func = load_function(raw) + return self.shards[sid].execute(func).wait(timeout=15) + + def run(self): + for shard in range(self.config.shard_count): + if self.config.manhole_enable and shard != 0: + self.config.manhole_enable = False + + self.start_shard(shard) + gevent.sleep(6) + + logging.basicConfig( + level=logging.INFO, + format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id) + ) + + @staticmethod + def dumps(data): + if isinstance(data, (basestring, int, long, bool, list, set, dict)): + return '\x01' + marshal.dumps(data) + elif isinstance(data, object) and data.__class__.__name__ == 'code': + return '\x01' + marshal.dumps(data) + else: + return '\x02' + pickle.dumps(data) + + @staticmethod + def loads(data): + enc_type = data[0] + if enc_type == '\x01': + return marshal.loads(data[1:]) + elif enc_type == '\x02': + return pickle.loads(data[1:]) + + def start_shard(self, sid): + cpipe, ppipe = gipc.pipe(duplex=True, encoder=self.dumps, decoder=self.loads) + gipc.start_process(run_shard, (self.config, sid, cpipe)) + self.shards[sid] = GIPCProxy(self, ppipe) diff --git a/disco/state.py b/disco/state.py index ae50698..689ab62 100644 --- a/disco/state.py +++ b/disco/state.py @@ -1,10 +1,11 @@ import six +import weakref import inflection from collections import deque, namedtuple -from weakref import WeakValueDictionary from gevent.event import Event +from disco.types.base import UNSET from disco.util.config import Config from disco.util.hashmap import HashMap, DefaultHashMap @@ -88,7 +89,7 @@ class State(object): EVENTS = [ 'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove', 'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete', - 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate', + 'GuildEmojisUpdate', 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate', 'PresenceUpdate' ] @@ -102,9 +103,9 @@ class State(object): self.me = None self.dms = HashMap() self.guilds = HashMap() - self.channels = HashMap(WeakValueDictionary()) - self.users = HashMap(WeakValueDictionary()) - self.voice_states = HashMap(WeakValueDictionary()) + self.channels = HashMap(weakref.WeakValueDictionary()) + self.users = HashMap(weakref.WeakValueDictionary()) + self.voice_states = HashMap(weakref.WeakValueDictionary()) # If message tracking is enabled, listen to those events if self.config.track_messages: @@ -117,7 +118,7 @@ class State(object): def unbind(self): """ - Unbinds all bound event listeners for this state object + Unbinds all bound event listeners for this state object. """ map(lambda k: k.unbind(), self.listeners) self.listeners = [] @@ -185,11 +186,19 @@ class State(object): for member in six.itervalues(event.guild.members): self.users[member.user.id] = member.user + for voice_state in six.itervalues(event.guild.voice_states): + self.voice_states[voice_state.session_id] = voice_state + if self.config.sync_guild_members: event.guild.sync() def on_guild_update(self, event): - self.guilds[event.guild.id].update(event.guild) + self.guilds[event.guild.id].update(event.guild, ignored=[ + 'channels', + 'members', + 'voice_states', + 'presences' + ]) def on_guild_delete(self, event): if event.id in self.guilds: @@ -208,6 +217,10 @@ class State(object): if event.channel.id in self.channels: self.channels[event.channel.id].update(event.channel) + if event.overwrites is not UNSET: + self.channels[event.channel.id].overwrites = event.overwrites + self.channels[event.channel.id].after_load() + def on_channel_delete(self, event): if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels: del event.channel.guild.channels[event.channel.id] @@ -215,18 +228,22 @@ class State(object): del self.dms[event.channel.id] def on_voice_state_update(self, event): - # Happy path: we have the voice state and want to update/delete it - guild = self.guilds.get(event.state.guild_id) - if not guild: - return - - if event.state.session_id in guild.voice_states: + # Existing connection, we are either moving channels or disconnecting + if event.state.session_id in self.voice_states: + # Moving channels if event.state.channel_id: - guild.voice_states[event.state.session_id].update(event.state) + self.voice_states[event.state.session_id].update(event.state) + # Disconnection else: - del guild.voice_states[event.state.session_id] + if event.state.guild_id in self.guilds: + if event.state.session_id in self.guilds[event.state.guild_id].voice_states: + del self.guilds[event.state.guild_id].voice_states[event.state.session_id] + del self.voice_states[event.state.session_id] + # New connection elif event.state.channel_id: - guild.voice_states[event.state.session_id] = event.state + if event.state.guild_id in self.guilds: + self.guilds[event.state.guild_id].voice_states[event.state.session_id] = event.state + self.voice_states[event.state.session_id] = event.state def on_guild_member_add(self, event): if event.member.user.id not in self.users: @@ -243,6 +260,9 @@ class State(object): if event.member.guild_id not in self.guilds: return + if event.member.id not in self.guilds[event.member.guild_id].members: + return + self.guilds[event.member.guild_id].members[event.member.id].update(event.member) def on_guild_member_remove(self, event): @@ -285,6 +305,22 @@ class State(object): del self.guilds[event.guild_id].roles[event.role_id] + def on_guild_emojis_update(self, event): + if event.guild_id not in self.guilds: + return + + self.guilds[event.guild_id].emojis = HashMap({i.id: i for i in event.emojis}) + def on_presence_update(self, event): if event.user.id in self.users: + self.users[event.user.id].update(event.presence.user) self.users[event.user.id].presence = event.presence + event.presence.user = self.users[event.user.id] + + if event.guild_id not in self.guilds: + return + + if event.user.id not in self.guilds[event.guild_id].members: + return + + self.guilds[event.guild_id].members[event.user.id].user.update(event.user) diff --git a/disco/types/__init__.py b/disco/types/__init__.py index 5e6f73b..5824ec5 100644 --- a/disco/types/__init__.py +++ b/disco/types/__init__.py @@ -1,3 +1,4 @@ +from disco.types.base import UNSET from disco.types.channel import Channel from disco.types.guild import Guild, GuildMember, Role from disco.types.user import User diff --git a/disco/types/base.py b/disco/types/base.py index 63bc628..c3a061b 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -3,7 +3,7 @@ import gevent import inspect import functools -from holster.enum import BaseEnumMeta +from holster.enum import BaseEnumMeta, EnumAttr from datetime import datetime as real_datetime from disco.util.functional import CachedSlotProperty @@ -15,45 +15,61 @@ DATETIME_FORMATS = [ ] +def get_item_by_path(obj, path): + for part in path.split('.'): + obj = getattr(obj, part) + return obj + + +class Unset(object): + def __nonzero__(self): + return False + + +UNSET = Unset() + + class ConversionError(Exception): def __init__(self, field, raw, e): super(ConversionError, self).__init__( 'Failed to convert `{}` (`{}`) to {}: {}'.format( - str(raw)[:144], field.src_name, field.typ, e)) + str(raw)[:144], field.src_name, field.true_type, e)) + if six.PY3: + self.__cause__ = e -class FieldType(object): - def __init__(self, typ): - if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model): - self.typ = typ - elif isinstance(typ, BaseEnumMeta): - self.typ = lambda raw, _: typ.get(raw) - elif typ is None: - self.typ = lambda x, y: None - else: - self.typ = lambda raw, _: typ(raw) - def try_convert(self, raw, client): - pass - - def __call__(self, raw, client): - return self.try_convert(raw, client) +class Field(object): + def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, cast=None, **kwargs): + # TODO: fix default bullshit + self.true_type = value_type + self.src_name = alias + self.dst_name = None + self.ignore_dump = ignore_dump or [] + self.cast = cast + self.metadata = kwargs + if default is not None: + self.default = default + elif not hasattr(self, 'default'): + self.default = None -class Field(FieldType): - def __init__(self, typ, alias=None, default=None): - super(Field, self).__init__(typ) + self.deserializer = None - # Set names - self.src_name = alias - self.dst_name = None + if value_type: + self.deserializer = self.type_to_deserializer(value_type) - self.default = default + if isinstance(self.deserializer, Field) and self.default is None: + self.default = self.deserializer.default + elif inspect.isclass(self.deserializer) and issubclass(self.deserializer, Model) and self.default is None and create: + self.default = self.deserializer - if isinstance(self.typ, FieldType): - self.default = self.typ.default + @property + def name(self): + return None - def set_name(self, name): + @name.setter + def name(self, name): if not self.dst_name: self.dst_name = name @@ -65,31 +81,82 @@ class Field(FieldType): def try_convert(self, raw, client): try: - return self.typ(raw, client) + return self.deserializer(raw, client) except Exception as e: - six.raise_from(ConversionError(self, raw, e), e) + six.reraise(ConversionError, ConversionError(self, raw, e)) + @staticmethod + def type_to_deserializer(typ): + if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model): + return typ + elif isinstance(typ, BaseEnumMeta): + return lambda raw, _: typ.get(raw) + elif typ is None: + return lambda x, y: None + else: + return lambda raw, _: typ(raw) + + @staticmethod + def serialize(value, inst=None): + if isinstance(value, EnumAttr): + return value.value + elif isinstance(value, Model): + return value.to_dict(ignore=(inst.ignore_dump if inst else [])) + else: + if inst and inst.cast: + return inst.cast(value) + return value -class _Dict(FieldType): + def __call__(self, raw, client): + return self.try_convert(raw, client) + + +class DictField(Field): default = HashMap - def __init__(self, typ, key=None): - super(_Dict, self).__init__(typ) - self.key = key + def __init__(self, key_type, value_type=None, **kwargs): + super(DictField, self).__init__({}, **kwargs) + self.true_key_type = key_type + self.true_value_type = value_type + self.key_de = self.type_to_deserializer(key_type) + self.value_de = self.type_to_deserializer(value_type or key_type) + + @staticmethod + def serialize(value, inst=None): + return { + Field.serialize(k): Field.serialize(v) for k, v in six.iteritems(value) + if k not in (inst.ignore_dump if inst else []) + } def try_convert(self, raw, client): - if self.key: - converted = [self.typ(i, client) for i in raw] - return HashMap({getattr(i, self.key): i for i in converted}) - else: - return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)}) + return HashMap({ + self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) + }) -class _List(FieldType): +class ListField(Field): default = list + @staticmethod + def serialize(value, inst=None): + return list(map(Field.serialize, value)) + + def try_convert(self, raw, client): + return [self.deserializer(i, client) for i in raw] + + +class AutoDictField(Field): + default = HashMap + + def __init__(self, value_type, key, **kwargs): + super(AutoDictField, self).__init__({}, **kwargs) + self.value_de = self.type_to_deserializer(value_type) + self.key = key + def try_convert(self, raw, client): - return [self.typ(i, client) for i in raw] + return HashMap({ + getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw) + }) def _make(typ, data, client): @@ -104,37 +171,19 @@ def snowflake(data): def enum(typ): def _f(data): + if isinstance(data, str): + data = data.lower() return typ.get(data) if data is not None else None return _f -def listof(*args, **kwargs): - return _List(*args, **kwargs) - - -def dictof(*args, **kwargs): - return _Dict(*args, **kwargs) - - -def lazy_datetime(data): - if not data: - return property(lambda: None) - - def get(): - for fmt in DATETIME_FORMATS: - try: - return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) - except (ValueError, TypeError): - continue - raise ValueError('Failed to conver `{}` to datetime'.format(data)) - - return property(get) - - def datetime(data): if not data: return None + if isinstance(data, int): + return real_datetime.utcfromtimestamp(data) + for fmt in DATETIME_FORMATS: try: return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) @@ -145,6 +194,9 @@ def datetime(data): def text(obj): + if obj is None: + return None + if six.PY2: if isinstance(obj, str): return obj.decode('utf-8') @@ -154,6 +206,9 @@ def text(obj): def binary(obj): + if obj is None: + return None + if six.PY2: if isinstance(obj, str): return obj.decode('utf-8') @@ -165,13 +220,16 @@ def binary(obj): def with_equality(field): class T(object): def __eq__(self, other): - return getattr(self, field) == getattr(other, field) + if isinstance(other, self.__class__): + return getattr(self, field) == getattr(other, field) + else: + return getattr(self, field) == other return T def with_hash(field): class T(object): - def __hash__(self, other): + def __hash__(self): return hash(getattr(self, field)) return T @@ -182,7 +240,7 @@ SlottedModel = None class ModelMeta(type): - def __new__(cls, name, parents, dct): + def __new__(mcs, name, parents, dct): fields = {} for parent in parents: @@ -193,7 +251,7 @@ class ModelMeta(type): if not isinstance(v, Field): continue - v.set_name(k) + v.name = k fields[k] = v if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)): @@ -209,7 +267,7 @@ class ModelMeta(type): dct = {k: v for k, v in six.iteritems(dct) if k not in fields} dct['_fields'] = fields - return super(ModelMeta, cls).__new__(cls, name, parents, dct) + return super(ModelMeta, mcs).__new__(mcs, name, parents, dct) class AsyncChainable(object): @@ -233,23 +291,49 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): else: obj = kwargs - for name, field in six.iteritems(self.__class__._fields): - if field.src_name not in obj or obj[field.src_name] is None: - if field.has_default(): - default = field.default() if callable(field.default) else field.default - else: - default = None - setattr(self, field.dst_name, default) + self.load(obj) + self.validate() + + def validate(self): + pass + + @property + def _fields(self): + return self.__class__._fields + + def load(self, obj, consume=False, skip=None): + return self.load_into(self, obj, consume, skip) + + def load_into(self, inst, obj, consume=False, skip=None): + for name, field in six.iteritems(self._fields): + should_skip = skip and name in skip + + if consume and not should_skip: + raw = obj.pop(field.src_name, UNSET) + else: + raw = obj.get(field.src_name, UNSET) + + # If the field is unset/none, and we have a default we need to set it + if (raw in (None, UNSET) or should_skip) and field.has_default(): + default = field.default() if callable(field.default) else field.default + setattr(inst, field.dst_name, default) + continue + + # Otherwise if the field is UNSET and has no default, skip conversion + if raw is UNSET or should_skip: + setattr(inst, field.dst_name, raw) continue - value = field.try_convert(obj[field.src_name], self.client) - setattr(self, field.dst_name, value) + value = field.try_convert(raw, self.client) + setattr(inst, field.dst_name, value) - def update(self, other): - for name in six.iterkeys(self.__class__._fields): - value = getattr(other, name) - if value: - setattr(self, name, value) + def update(self, other, ignored=None): + for name in six.iterkeys(self._fields): + if ignored and name in ignored: + continue + + if hasattr(other, name) and not getattr(other, name) is UNSET: + setattr(self, name, getattr(other, name)) # Clear cached properties for name in dir(type(self)): @@ -259,8 +343,16 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): except: pass - def to_dict(self): - return {k: getattr(self, k) for k in six.iterkeys(self.__class__._fields)} + def to_dict(self, ignore=None): + obj = {} + for name, field in six.iteritems(self.__class__._fields): + if ignore and name in ignore: + continue + + if getattr(self, name) == UNSET: + continue + obj[name] = field.serialize(getattr(self, name), field) + return obj @classmethod def create(cls, client, data, **kwargs): @@ -269,8 +361,16 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)): return inst @classmethod - def create_map(cls, client, data): - return list(map(functools.partial(cls.create, client), data)) + def create_map(cls, client, data, **kwargs): + return list(map(functools.partial(cls.create, client, **kwargs), data)) + + @classmethod + def create_hash(cls, client, key, data, **kwargs): + return HashMap({ + get_item_by_path(item, key): item + for item in [ + cls.create(client, item, **kwargs) for item in data] + }) @classmethod def attach(cls, it, data): diff --git a/disco/types/channel.py b/disco/types/channel.py index f824a35..311ca5c 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,11 +1,12 @@ import six +from six.moves import map from holster.enum import Enum from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property, one_or_many, chunks from disco.types.user import User -from disco.types.base import SlottedModel, Field, snowflake, enum, listof, dictof, text +from disco.types.base import SlottedModel, Field, AutoDictField, snowflake, enum, text from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.voice.client import VoiceClient @@ -33,7 +34,7 @@ class ChannelSubType(SlottedModel): class PermissionOverwrite(ChannelSubType): """ - A PermissionOverwrite for a :class:`Channel` + A PermissionOverwrite for a :class:`Channel`. Attributes ---------- @@ -48,8 +49,8 @@ class PermissionOverwrite(ChannelSubType): """ id = Field(snowflake) type = Field(enum(PermissionOverwriteType)) - allow = Field(PermissionValue) - deny = Field(PermissionValue) + allow = Field(PermissionValue, cast=int) + deny = Field(PermissionValue, cast=int) channel_id = Field(snowflake) @@ -57,22 +58,29 @@ class PermissionOverwrite(ChannelSubType): def create(cls, channel, entity, allow=0, deny=0): from disco.types.guild import Role - type = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER + ptype = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER return cls( client=channel.client, id=entity.id, - type=type, + type=ptype, allow=allow, deny=deny, channel_id=channel.id ).save() + @property + def compiled(self): + value = PermissionValue() + value -= self.deny + value += self.allow + return value + def save(self): self.client.api.channels_permissions_modify(self.channel_id, - self.id, - self.allow.value or 0, - self.deny.value or 0, - self.type.name) + self.id, + self.allow.value or 0, + self.deny.value or 0, + self.type.name) return self def delete(self): @@ -81,7 +89,7 @@ class PermissionOverwrite(ChannelSubType): class Channel(SlottedModel, Permissible): """ - Represents a Discord Channel + Represents a Discord Channel. Attributes ---------- @@ -111,18 +119,27 @@ class Channel(SlottedModel, Permissible): last_message_id = Field(snowflake) position = Field(int) bitrate = Field(int) - recipients = Field(listof(User)) + recipients = AutoDictField(User, 'id') type = Field(enum(ChannelType)) - overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') + overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites') def __init__(self, *args, **kwargs): super(Channel, self).__init__(*args, **kwargs) + self.after_load() + def after_load(self): + # TODO: hackfix self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) + def __str__(self): + return u'#{}'.format(self.name) + + def __repr__(self): + return u''.format(self.id, self) + def get_permissions(self, user): """ - Get the permissions a user has in the channel + Get the permissions a user has in the channel. Returns ------- @@ -132,8 +149,8 @@ class Channel(SlottedModel, Permissible): if not self.guild_id: return Permissions.ADMINISTRATOR - member = self.guild.members.get(user.id) - base = self.guild.get_permissions(user) + member = self.guild.get_member(user) + base = self.guild.get_permissions(member) for ow in six.itervalues(self.overwrites): if ow.id != user.id and ow.id not in member.roles: @@ -144,48 +161,55 @@ class Channel(SlottedModel, Permissible): return base + @property + def mention(self): + return '<#{}>'.format(self.id) + @property def is_guild(self): """ - Whether this channel belongs to a guild + Whether this channel belongs to a guild. """ return self.type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE) @property def is_dm(self): """ - Whether this channel is a DM (does not belong to a guild) + Whether this channel is a DM (does not belong to a guild). """ return self.type in (ChannelType.DM, ChannelType.GROUP_DM) @property def is_voice(self): """ - Whether this channel supports voice + Whether this channel supports voice. """ return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM) @property def messages(self): """ - a default :class:`MessageIterator` for the channel + a default :class:`MessageIterator` for the channel. """ return self.messages_iter() @cached_property def guild(self): """ - Guild this channel belongs to (if relevant) + Guild this channel belongs to (if relevant). """ return self.client.state.guilds.get(self.guild_id) def messages_iter(self, **kwargs): """ Creates a new :class:`MessageIterator` for the channel with the given - keyword arguments + keyword arguments. """ return MessageIterator(self.client, self, **kwargs) + def get_message(self, message): + return self.client.api.channels_messages_get(self.id, to_snowflake(message)) + def get_invites(self): """ Returns @@ -220,9 +244,9 @@ class Channel(SlottedModel, Permissible): def create_webhook(self, name=None, avatar=None): return self.client.api.channels_webhooks_create(self.id, name, avatar) - def send_message(self, content, nonce=None, tts=False): + def send_message(self, content, nonce=None, tts=False, attachment=None, embed=None): """ - Send a message in this channel + Send a message in this channel. Parameters ---------- @@ -238,11 +262,11 @@ class Channel(SlottedModel, Permissible): :class:`disco.types.message.Message` The created message. """ - return self.client.api.channels_messages_create(self.id, content, nonce, tts) + return self.client.api.channels_messages_create(self.id, content, nonce, tts, attachment, embed) def connect(self, *args, **kwargs): """ - Connect to this channel over voice + Connect to this channel over voice. """ assert self.is_voice, 'Channel must support voice to connect' vc = VoiceClient(self) @@ -275,17 +299,29 @@ class Channel(SlottedModel, Permissible): List of messages (or message ids) to delete. All messages must originate from this channel. """ - messages = map(to_snowflake, messages) + message_ids = list(map(to_snowflake, messages)) - if not messages: + if not message_ids: return - if len(messages) <= 2: + if self.can(self.client.state.me, Permissions.MANAGE_MESSAGES) and len(messages) > 2: + for chunk in chunks(message_ids, 100): + self.client.api.channels_messages_delete_bulk(self.id, chunk) + else: for msg in messages: self.delete_message(msg) - else: - for chunk in chunks(messages, 100): - self.client.api.channels_messages_delete_bulk(self.id, chunk) + + def delete(self): + assert (self.is_dm or self.guild.can(self.client.state.me, Permissions.MANAGE_GUILD)), 'Invalid Permissions' + self.client.api.channels_delete(self.id) + + def close(self): + """ + Closes a DM channel. This is intended as a safer version of `delete`, + enforcing that the channel is actually a DM. + """ + assert self.is_dm, 'Cannot close non-DM channel' + self.delete() class MessageIterator(object): @@ -329,7 +365,7 @@ class MessageIterator(object): def fill(self): """ - Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API + Fills the internal buffer up with :class:`disco.types.message.Message` objects from the API. """ self._buffer = self.client.api.channels_messages_list( self.channel.id, diff --git a/disco/types/guild.py b/disco/types/guild.py index 9708e3e..f93a88f 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -6,10 +6,13 @@ from disco.gateway.packets import OPCode from disco.api.http import APIException from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property -from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum -from disco.types.user import User +from disco.types.base import ( + SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum, datetime +) +from disco.types.user import User, Presence from disco.types.voice import VoiceState from disco.types.channel import Channel +from disco.types.message import Emoji from disco.types.permissions import PermissionValue, Permissions, Permissible @@ -18,21 +21,12 @@ VerificationLevel = Enum( LOW=1, MEDIUM=2, HIGH=3, - EXTREME=4, ) -class GuildSubType(SlottedModel): - guild_id = Field(None) - - @cached_property - def guild(self): - return self.client.state.guilds.get(self.guild_id) - - -class Emoji(GuildSubType): +class GuildEmoji(Emoji): """ - An emoji object + An emoji object. Attributes ---------- @@ -48,15 +42,27 @@ class Emoji(GuildSubType): Roles this emoji is attached to. """ id = Field(snowflake) + guild_id = Field(snowflake) name = Field(text) require_colons = Field(bool) managed = Field(bool) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) + def __str__(self): + return u'<:{}:{}>'.format(self.name, self.id) -class Role(GuildSubType): + @property + def url(self): + return 'https://discordapp.com/api/emojis/{}.png'.format(self.id) + + @cached_property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + + +class Role(SlottedModel): """ - A role object + A role object. Attributes ---------- @@ -76,6 +82,7 @@ class Role(GuildSubType): The position of this role in the hierarchy. """ id = Field(snowflake) + guild_id = Field(snowflake) name = Field(text) hoist = Field(bool) managed = Field(bool) @@ -84,6 +91,9 @@ class Role(GuildSubType): position = Field(int) mentionable = Field(bool) + def __str__(self): + return self.name + def delete(self): self.guild.delete_role(self) @@ -94,10 +104,19 @@ class Role(GuildSubType): def mention(self): return '<@{}>'.format(self.id) + @cached_property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + -class GuildMember(GuildSubType): +class GuildBan(SlottedModel): + user = Field(User) + reason = Field(str) + + +class GuildMember(SlottedModel): """ - A GuildMember object + A GuildMember object. Attributes ---------- @@ -121,8 +140,18 @@ class GuildMember(GuildSubType): nick = Field(text) mute = Field(bool) deaf = Field(bool) - joined_at = Field(str) - roles = Field(listof(snowflake)) + joined_at = Field(datetime) + roles = ListField(snowflake) + + def __str__(self): + return self.user.__str__() + + @property + def name(self): + """ + The nickname of this user if set, otherwise their username + """ + return self.nick or self.user.username def get_voice_state(self): """ @@ -151,6 +180,12 @@ class GuildMember(GuildSubType): """ self.guild.create_ban(self, delete_message_days) + def unban(self): + """ + Unbans the member from the guild. + """ + self.guild.delete_ban(self) + def set_nickname(self, nickname=None): """ Sets the member's nickname (or clears it if None). @@ -160,11 +195,19 @@ class GuildMember(GuildSubType): nickname : Optional[str] The nickname (or none to reset) to set. """ - self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') + if self.client.state.me.id == self.user.id: + self.client.api.guilds_members_me_nick(self.guild.id, nick=nickname or '') + else: + self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') + + def modify(self, **kwargs): + self.client.api.guilds_members_modify(self.guild.id, self.user.id, **kwargs) def add_role(self, role): - roles = self.roles + [role.id] - self.client.api.guilds_members_modify(self.guild.id, self.user.id, roles=roles) + self.client.api.guilds_members_roles_add(self.guild.id, self.user.id, to_snowflake(role)) + + def remove_role(self, role): + self.client.api.guilds_members_roles_remove(self.guild.id, self.user.id, to_snowflake(role)) @cached_property def owner(self): @@ -179,14 +222,22 @@ class GuildMember(GuildSubType): @property def id(self): """ - Alias to the guild members user id + Alias to the guild members user id. """ return self.user.id + @cached_property + def guild(self): + return self.client.state.guilds.get(self.guild_id) + + @cached_property + def permissions(self): + return self.guild.get_permissions(self) + class Guild(SlottedModel, Permissible): """ - A guild object + A guild object. Attributes ---------- @@ -222,7 +273,7 @@ class Guild(SlottedModel, Permissible): All of the guild's channels. roles : dict(snowflake, :class:`Role`) All of the guild's roles. - emojis : dict(snowflake, :class:`Emoji`) + emojis : dict(snowflake, :class:`GuildEmoji`) All of the guild's emojis. voice_states : dict(str, :class:`disco.types.voice.VoiceState`) All of the guild's voice states. @@ -239,12 +290,14 @@ class Guild(SlottedModel, Permissible): embed_enabled = Field(bool) verification_level = Field(enum(VerificationLevel)) mfa_level = Field(int) - features = Field(listof(str)) - members = Field(dictof(GuildMember, key='id')) - channels = Field(dictof(Channel, key='id')) - roles = Field(dictof(Role, key='id')) - emojis = Field(dictof(Emoji, key='id')) - voice_states = Field(dictof(VoiceState, key='session_id')) + features = ListField(str) + members = AutoDictField(GuildMember, 'id') + channels = AutoDictField(Channel, 'id') + roles = AutoDictField(Role, 'id') + emojis = AutoDictField(GuildEmoji, 'id') + voice_states = AutoDictField(VoiceState, 'session_id') + member_count = Field(int) + presences = ListField(Presence) synced = Field(bool, default=False) @@ -257,7 +310,7 @@ class Guild(SlottedModel, Permissible): self.attach(six.itervalues(self.emojis), {'guild_id': self.id}) self.attach(six.itervalues(self.voice_states), {'guild_id': self.id}) - def get_permissions(self, user): + def get_permissions(self, member): """ Get the permissions a user has in this guild. @@ -266,10 +319,13 @@ class Guild(SlottedModel, Permissible): :class:`disco.types.permissions.PermissionValue` Computed permission value for the user. """ - if self.owner_id == user.id: + if not isinstance(member, GuildMember): + member = self.get_member(member) + + # Owner has all permissions + if self.owner_id == member.id: return PermissionValue(Permissions.ADMINISTRATOR) - member = self.get_member(user) value = PermissionValue(self.roles.get(self.id).permissions) for role in map(self.roles.get, member.roles): @@ -358,3 +414,6 @@ class Guild(SlottedModel, Permissible): def create_ban(self, user, delete_message_days=0): self.client.api.guilds_bans_create(self.id, to_snowflake(user), delete_message_days) + + def create_channel(self, *args, **kwargs): + return self.client.api.guilds_channels_create(self.id, *args, **kwargs) diff --git a/disco/types/invite.py b/disco/types/invite.py index 850002e..0cc4852 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,4 +1,4 @@ -from disco.types.base import SlottedModel, Field, lazy_datetime +from disco.types.base import SlottedModel, Field, datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel @@ -6,7 +6,7 @@ from disco.types.channel import Channel class Invite(SlottedModel): """ - An invite object + An invite object. Attributes ---------- @@ -37,7 +37,7 @@ class Invite(SlottedModel): max_uses = Field(int) uses = Field(int) temporary = Field(bool) - created_at = Field(lazy_datetime) + created_at = Field(datetime) @classmethod def create(cls, channel, max_age=86400, max_uses=0, temporary=False, unique=False): diff --git a/disco/types/message.py b/disco/types/message.py index 15f751d..9b33d11 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,8 +1,14 @@ import re +import six +import functools +import unicodedata from holster.enum import Enum -from disco.types.base import SlottedModel, Field, snowflake, text, lazy_datetime, dictof, listof, enum +from disco.types.base import ( + SlottedModel, Field, ListField, AutoDictField, snowflake, text, + datetime, enum +) from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.user import User @@ -19,6 +25,31 @@ MessageType = Enum( ) +class Emoji(SlottedModel): + id = Field(snowflake) + name = Field(text) + + def __eq__(self, other): + if isinstance(other, Emoji): + return self.id == other.id and self.name == other.name + raise NotImplementedError + + def to_string(self): + if self.id: + return '{}:{}'.format(self.name, self.id) + return self.name + + +class MessageReactionEmoji(Emoji): + pass + + +class MessageReaction(SlottedModel): + emoji = Field(MessageReactionEmoji) + count = Field(int) + me = Field(bool) + + class MessageEmbedFooter(SlottedModel): text = Field(text) icon_url = Field(text) @@ -60,7 +91,7 @@ class MessageEmbedField(SlottedModel): class MessageEmbed(SlottedModel): """ - Message embed object + Message embed object. Attributes ---------- @@ -76,20 +107,38 @@ class MessageEmbed(SlottedModel): title = Field(text) type = Field(str, default='rich') description = Field(text) - url = Field(str) - timestamp = Field(lazy_datetime) + url = Field(text) + timestamp = Field(datetime) color = Field(int) footer = Field(MessageEmbedFooter) image = Field(MessageEmbedImage) thumbnail = Field(MessageEmbedThumbnail) video = Field(MessageEmbedVideo) author = Field(MessageEmbedAuthor) - fields = Field(listof(MessageEmbedField)) + fields = ListField(MessageEmbedField) + + def set_footer(self, *args, **kwargs): + self.footer = MessageEmbedFooter(*args, **kwargs) + + def set_image(self, *args, **kwargs): + self.image = MessageEmbedImage(*args, **kwargs) + + def set_thumbnail(self, *args, **kwargs): + self.thumbnail = MessageEmbedThumbnail(*args, **kwargs) + + def set_video(self, *args, **kwargs): + self.video = MessageEmbedVideo(*args, **kwargs) + + def set_author(self, *args, **kwargs): + self.author = MessageEmbedAuthor(*args, **kwargs) + + def add_field(self, *args, **kwargs): + self.fields.append(MessageEmbedField(*args, **kwargs)) class MessageAttachment(SlottedModel): """ - Message attachment object + Message attachment object. Attributes ---------- @@ -110,8 +159,8 @@ class MessageAttachment(SlottedModel): """ id = Field(str) filename = Field(text) - url = Field(str) - proxy_url = Field(str) + url = Field(text) + proxy_url = Field(text) size = Field(int) height = Field(int) width = Field(int) @@ -161,15 +210,16 @@ class Message(SlottedModel): author = Field(User) content = Field(text) nonce = Field(snowflake) - timestamp = Field(lazy_datetime) - edited_timestamp = Field(lazy_datetime) + timestamp = Field(datetime) + edited_timestamp = Field(datetime) tts = Field(bool) mention_everyone = Field(bool) pinned = Field(bool) - mentions = Field(dictof(User, key='id')) - mention_roles = Field(listof(snowflake)) - embeds = Field(listof(MessageEmbed)) - attachments = Field(dictof(MessageAttachment, key='id')) + mentions = AutoDictField(User, 'id') + mention_roles = ListField(snowflake) + embeds = ListField(MessageEmbed) + attachments = AutoDictField(MessageAttachment, 'id') + reactions = ListField(MessageReaction) def __str__(self): return ''.format(self.id, self.channel_id) @@ -213,7 +263,7 @@ class Message(SlottedModel): def reply(self, *args, **kwargs): """ Reply to this message (proxys arguments to - :func:`disco.types.channel.Channel.send_message`) + :func:`disco.types.channel.Channel.send_message`). Returns ------- @@ -222,9 +272,9 @@ class Message(SlottedModel): """ return self.channel.send_message(*args, **kwargs) - def edit(self, content): + def edit(self, *args, **kwargs): """ - Edit this message + Edit this message. Args ---- @@ -236,7 +286,7 @@ class Message(SlottedModel): :class:`Message` The edited message object. """ - return self.client.api.channels_messages_modify(self.channel_id, self.id, content) + return self.client.api.channels_messages_modify(self.channel_id, self.id, *args, **kwargs) def delete(self): """ @@ -249,6 +299,42 @@ class Message(SlottedModel): """ return self.client.api.channels_messages_delete(self.channel_id, self.id) + def get_reactors(self, emoji): + """ + Returns an list of users who reacted to this message with the given emoji. + + Returns + ------- + list(:class:`User`) + The users who reacted. + """ + return self.client.api.channels_messages_reactions_get( + self.channel_id, + self.id, + emoji + ) + + def create_reaction(self, emoji): + if isinstance(emoji, Emoji): + emoji = emoji.to_string() + self.client.api.channels_messages_reactions_create( + self.channel_id, + self.id, + emoji) + + def delete_reaction(self, emoji, user=None): + if isinstance(emoji, Emoji): + emoji = emoji.to_string() + + if user: + user = to_snowflake(user) + + self.client.api.channels_messages_reactions_delete( + self.channel_id, + self.id, + emoji, + user) + def is_mentioned(self, entity): """ Returns @@ -256,22 +342,37 @@ class Message(SlottedModel): bool Whether the give entity was mentioned. """ - id = to_snowflake(entity) - return id in self.mentions or id in self.mention_roles + entity = to_snowflake(entity) + return entity in self.mentions or entity in self.mention_roles @cached_property - def without_mentions(self): + def without_mentions(self, valid_only=False): """ Returns ------- str - the message contents with all valid mentions removed. + the message contents with all mentions removed. """ return self.replace_mentions( lambda u: '', - lambda r: '') + lambda r: '', + lambda c: '', + nonexistant=not valid_only) + + @cached_property + def with_proper_mentions(self): + def replace_user(u): + return u'@' + six.text_type(u) + + def replace_role(r): + return u'@' + six.text_type(r) + + def replace_channel(c): + return six.text_type(c) - def replace_mentions(self, user_replace, role_replace): + return self.replace_mentions(replace_user, replace_role, replace_channel) + + def replace_mentions(self, user_replace=None, role_replace=None, channel_replace=None, nonexistant=False): """ Replaces user and role mentions with the result of a given lambda/function. @@ -289,39 +390,55 @@ class Message(SlottedModel): str The message contents with all valid mentions replaced. """ - if not self.mentions and not self.mention_roles: - return + def replace(getter, func, match): + oid = int(match.group(2)) + obj = getter(oid) + + if obj or nonexistant: + return func(obj or oid) or match.group(0) + + return match.group(0) + + content = self.content + + if user_replace: + replace_user = functools.partial(replace, self.mentions.get, user_replace) + content = re.sub('(<@!?([0-9]+)>)', replace_user, content) + + if role_replace: + replace_role = functools.partial(replace, lambda v: (self.guild and self.guild.roles.get(v)), role_replace) + content = re.sub('(<@&([0-9]+)>)', replace_role, content) - def replace(match): - id = match.group(0) - if id in self.mention_roles: - return role_replace(id) - else: - return user_replace(self.mentions.get(id)) + if channel_replace: + replace_channel = functools.partial(replace, self.client.state.channels.get, channel_replace) + content = re.sub('(<#([0-9]+)>)', replace_channel, content) - return re.sub('<@!?([0-9]+)>', replace, self.content) + return content class MessageTable(object): - def __init__(self, sep=' | ', codeblock=True, header_break=True): + def __init__(self, sep=' | ', codeblock=True, header_break=True, language=None): self.header = [] self.entries = [] self.size_index = {} self.sep = sep self.codeblock = codeblock self.header_break = header_break + self.language = language def recalculate_size_index(self, cols): for idx, col in enumerate(cols): - if idx not in self.size_index or len(col) > self.size_index[idx]: - self.size_index[idx] = len(col) + size = len(unicodedata.normalize('NFC', col)) + if idx not in self.size_index or size > self.size_index[idx]: + self.size_index[idx] = size def set_header(self, *args): + args = list(map(six.text_type, args)) self.header = args self.recalculate_size_index(args) def add(self, *args): - args = list(map(str, args)) + args = list(map(six.text_type, args)) self.entries.append(args) self.recalculate_size_index(args) @@ -329,22 +446,23 @@ class MessageTable(object): data = self.sep.lstrip() for idx, col in enumerate(cols): - padding = ' ' * ((self.size_index[idx] - len(col))) + padding = ' ' * (self.size_index[idx] - len(col)) data += col + padding + self.sep return data.rstrip() def compile(self): data = [] - data.append(self.compile_one(self.header)) + if self.header: + data = [self.compile_one(self.header)] - if self.header_break: + if self.header and self.header_break: data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1)) for row in self.entries: data.append(self.compile_one(row)) if self.codeblock: - return '```' + '\n'.join(data) + '```' + return '```{}'.format(self.language if self.language else '') + '\n'.join(data) + '```' return '\n'.join(data) diff --git a/disco/types/permissions.py b/disco/types/permissions.py index aa7260c..6e4d9f3 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -76,13 +76,13 @@ class PermissionValue(object): return self.sub(other) def __getattribute__(self, name): - if name in Permissions.attrs: + if name in Permissions.keys_: return (self.value & Permissions[name].value) == Permissions[name].value else: return object.__getattribute__(self, name) def __setattr__(self, name, value): - if name not in Permissions.attrs: + if name not in Permissions.keys_: return super(PermissionValue, self).__setattr__(name, value) if value: @@ -90,9 +90,12 @@ class PermissionValue(object): else: self.value &= ~Permissions[name].value + def __int__(self): + return self.value + def to_dict(self): return { - k: getattr(self, k) for k in Permissions.attrs + k: getattr(self, k) for k in Permissions.keys_ } @classmethod @@ -107,6 +110,9 @@ class PermissionValue(object): class Permissible(object): __slots__ = [] + def get_permissions(self): + raise NotImplementedError + def can(self, user, *args): perms = self.get_permissions(user) return perms.administrator or perms.can(*args) diff --git a/disco/types/user.py b/disco/types/user.py index cad2cb8..3192abc 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -2,30 +2,54 @@ from holster.enum import Enum from disco.types.base import SlottedModel, Field, snowflake, text, binary, with_equality, with_hash +DefaultAvatars = Enum( + BLURPLE=0, + GREY=1, + GREEN=2, + ORANGE=3, + RED=4, +) + class User(SlottedModel, with_equality('id'), with_hash('id')): id = Field(snowflake) username = Field(text) avatar = Field(binary) discriminator = Field(str) - bot = Field(bool) + bot = Field(bool, default=False) verified = Field(bool) email = Field(str) presence = Field(None) + def get_avatar_url(self, fmt='webp', size=1024): + if not self.avatar: + return 'https://cdn.discordapp.com/embed/avatars/{}.png'.format(self.default_avatar.value) + + return 'https://cdn.discordapp.com/avatars/{}/{}.{}?size={}'.format( + self.id, + self.avatar, + fmt, + size + ) + + @property + def default_avatar(self): + return DefaultAvatars[int(self.discriminator) % len(DefaultAvatars.attrs)] + + @property + def avatar_url(self): + return self.get_avatar_url() + @property def mention(self): return '<@{}>'.format(self.id) - def to_string(self): - return '{}#{}'.format(self.username, self.discriminator) - def __str__(self): - return ''.format(self.id, self.to_string()) + return u'{}#{}'.format(self.username, str(self.discriminator).zfill(4)) - def on_create(self): - self.client.state.users[self.id] = self + def __repr__(self): + return u''.format(self.id, self) GameType = Enum( @@ -49,6 +73,6 @@ class Game(SlottedModel): class Presence(SlottedModel): - user = Field(User) + user = Field(User, alias='user', ignore_dump=['presence']) game = Field(Game) status = Field(Status) diff --git a/disco/types/voice.py b/disco/types/voice.py index 1647eb3..3d7cb32 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -17,7 +17,7 @@ class VoiceState(SlottedModel): def guild(self): return self.client.state.guilds.get(self.guild_id) - @cached_property + @property def channel(self): return self.client.state.channels.get(self.channel_id) diff --git a/disco/types/webhook.py b/disco/types/webhook.py index 3afdd3f..4a630d3 100644 --- a/disco/types/webhook.py +++ b/disco/types/webhook.py @@ -32,12 +32,14 @@ class Webhook(SlottedModel): else: return self.client.api.webhooks_modify(self.id, name, avatar) - def execute(self, content=None, username=None, avatar_url=None, tts=False, file=None, embeds=[], wait=False): + def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False): + # TODO: support file stuff properly + return self.client.api.webhooks_token_execute(self.id, self.token, { 'content': content, 'username': username, 'avatar_url': avatar_url, 'tts': tts, - 'file': file, + 'file': fobj, 'embeds': [i.to_dict() for i in embeds], }, wait) diff --git a/disco/util/config.py b/disco/util/config.py index 29147c2..30d2996 100644 --- a/disco/util/config.py +++ b/disco/util/config.py @@ -29,7 +29,7 @@ class Config(object): return inst def from_prefix(self, prefix): - prefix = prefix + '_' + prefix += '_' obj = {} for k, v in six.iteritems(self.__dict__): diff --git a/disco/util/hashmap.py b/disco/util/hashmap.py index ef32647..50cf6f4 100644 --- a/disco/util/hashmap.py +++ b/disco/util/hashmap.py @@ -45,12 +45,12 @@ class HashMap(UserDict): def filter(self, predicate): if not callable(predicate): raise TypeError('predicate must be callable') - return filter(self.values(), predicate) + return filter(predicate, self.values()) def map(self, predicate): if not callable(predicate): raise TypeError('predicate must be callable') - return map(self.values(), predicate) + return map(predicate, self.values()) class DefaultHashMap(defaultdict, HashMap): diff --git a/disco/util/limiter.py b/disco/util/limiter.py index 6992832..ccb7622 100644 --- a/disco/util/limiter.py +++ b/disco/util/limiter.py @@ -17,7 +17,8 @@ class SimpleLimiter(object): gevent.sleep(self.reset_at - time.time()) self.count = 0 self.reset_at = 0 - self.event.set() + if self.event: + self.event.set() self.event = None def check(self): diff --git a/disco/util/logging.py b/disco/util/logging.py index 7feca4d..75e9229 100644 --- a/disco/util/logging.py +++ b/disco/util/logging.py @@ -3,15 +3,28 @@ from __future__ import absolute_import import logging +LEVEL_OVERRIDES = { + 'requests': logging.WARNING +} + +LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s' + + +def setup_logging(**kwargs): + kwargs.setdefault('format', LOG_FORMAT) + + logging.basicConfig(**kwargs) + for logger, level in LEVEL_OVERRIDES.items(): + logging.getLogger(logger).setLevel(level) + + class LoggingClass(object): - def __init__(self): - self.log = logging.getLogger(self.__class__.__name__) - - def log_on_error(self, msg, f): - def _f(*args, **kwargs): - try: - return f(*args, **kwargs) - except: - self.log.exception(msg) - raise - return _f + __slots__ = ['_log'] + + @property + def log(self): + try: + return self._log + except AttributeError: + self._log = logging.getLogger(self.__class__.__name__) + return self._log diff --git a/disco/util/serializer.py b/disco/util/serializer.py index 74fe766..a481a9d 100644 --- a/disco/util/serializer.py +++ b/disco/util/serializer.py @@ -1,3 +1,5 @@ +import six +import types class Serializer(object): @@ -36,3 +38,37 @@ class Serializer(object): def dumps(cls, fmt, raw): _, dumps = getattr(cls, fmt)() return dumps(raw) + + +def dump_cell(cell): + return cell.cell_contents + + +def load_cell(cell): + if six.PY3: + return (lambda y: cell).__closure__[0] + else: + return (lambda y: cell).func_closure[0] + + +def dump_function(func): + if six.PY3: + return ( + func.__code__, + func.__name__, + func.__defaults__, + list(map(dump_cell, func.__closure__)) if func.__closure__ else [], + ) + else: + return ( + func.func_code, + func.func_name, + func.func_defaults, + list(map(dump_cell, func.func_closure)) if func.func_closure else [], + ) + + +def load_function(args): + code, name, defaults, closure = args + closure = tuple(map(load_cell, closure)) + return types.FunctionType(code, globals(), name, defaults, closure) diff --git a/disco/util/snowflake.py b/disco/util/snowflake.py index 241e2de..b2f512f 100644 --- a/disco/util/snowflake.py +++ b/disco/util/snowflake.py @@ -17,7 +17,7 @@ def to_unix(snowflake): def to_unix_ms(snowflake): - return ((int(snowflake) >> 22) + DISCORD_EPOCH) + return (int(snowflake) >> 22) + DISCORD_EPOCH def to_snowflake(i): diff --git a/disco/util/token.py b/disco/util/token.py index c48beca..d71b93d 100644 --- a/disco/util/token.py +++ b/disco/util/token.py @@ -5,6 +5,6 @@ TOKEN_RE = re.compile(r'M\w{23}\.[\w-]{6}\..{27}') def is_valid_token(token): """ - Validates a Discord authentication token, returning true if valid + Validates a Discord authentication token, returning true if valid. """ return bool(TOKEN_RE.match(token)) diff --git a/disco/voice/client.py b/disco/voice/client.py index 2ca00ba..327c4cd 100644 --- a/disco/voice/client.py +++ b/disco/voice/client.py @@ -106,6 +106,7 @@ class VoiceClient(LoggingClass): self.endpoint = None self.ssrc = None self.port = None + self.udp = None self.update_listener = None @@ -158,7 +159,7 @@ class VoiceClient(LoggingClass): } }) - def on_voice_sdp(self, data): + def on_voice_sdp(self, _): # Toggle speaking state so clients learn of our SSRC self.set_speaking(True) self.set_speaking(False) @@ -187,11 +188,10 @@ class VoiceClient(LoggingClass): def on_message(self, msg): try: data = self.encoder.decode(msg) + self.packets.emit(VoiceOPCode[data['op']], data['d']) except: self.log.exception('Failed to parse voice gateway message: ') - self.packets.emit(VoiceOPCode[data['op']], data['d']) - def on_error(self, err): # TODO self.log.warning('Voice websocket error: {}'.format(err)) @@ -205,6 +205,7 @@ class VoiceClient(LoggingClass): }) def on_close(self, code, error): + # TODO self.log.warning('Voice websocket disconnected (%s, %s)', code, error) if self.state == VoiceState.CONNECTED: diff --git a/disco/voice/opus.py b/disco/voice/opus.py index a1ad20e..b889cdf 100644 --- a/disco/voice/opus.py +++ b/disco/voice/opus.py @@ -1,8 +1,15 @@ import sys import array +import gevent import ctypes import ctypes.util +try: + from cStringIO import cStringIO as StringIO +except: + from StringIO import StringIO + +from gevent.queue import Queue from holster.enum import Enum from disco.util.logging import LoggingClass @@ -43,12 +50,12 @@ class BaseOpus(LoggingClass): for name, item in methods.items(): func = getattr(self.lib, name) - if item[1]: - func.argtypes = item[1] + if item[0]: + func.argtypes = item[0] - func.restype = item[2] + func.restype = item[1] - setattr(self, name.replace('opus_', ''), func) + setattr(self, name, func) @staticmethod def find_library(): @@ -83,7 +90,7 @@ class OpusEncoder(BaseOpus): } def __init__(self, sampling, channels, application=Application.AUDIO, library_path=None): - super(OpusDecoder, self).__init__(library_path) + super(OpusEncoder, self).__init__(library_path) self.sampling_rate = sampling self.channels = channels self.application = application @@ -94,10 +101,32 @@ class OpusEncoder(BaseOpus): self.frame_size = self.samples_per_frame * self.sample_size self.inst = self.create() + self.set_bitrate(128) + self.set_fec(True) + self.set_expected_packet_loss_percent(0.15) + + def set_bitrate(self, kbps): + kbps = min(128, max(16, int(kbps))) + ret = self.opus_encoder_ctl(self.inst, int(Control.SET_BITRATE), kbps * 1024) + + if ret < 0: + raise Exception('Failed to set bitrate to {}: {}'.format(kbps, ret)) + + def set_fec(self, value): + ret = self.opus_encoder_ctl(self.inst, int(Control.SET_FEC), int(value)) + + if ret < 0: + raise Exception('Failed to set FEC to {}: {}'.format(value, ret)) + + def set_expected_packet_loss_percent(self, perc): + ret = self.opus_encoder_ctl(self.inst, int(Control.SET_PLP), min(100, max(0, int(perc * 100)))) + + if ret < 0: + raise Exception('Failed to set PLP to {}: {}'.format(perc, ret)) def create(self): ret = ctypes.c_int() - result = self.encoder_create(self.sampling_rate, self.channels, self.application.value, ctypes.byref(ret)) + result = self.opus_encoder_create(self.sampling_rate, self.channels, self.application.value, ctypes.byref(ret)) if ret.value != 0: raise Exception('Failed to create opus encoder: {}'.format(ret.value)) @@ -106,7 +135,7 @@ class OpusEncoder(BaseOpus): def __del__(self): if self.inst: - self.encoder_destroy(self.inst) + self.opus_encoder_destroy(self.inst) self.inst = None def encode(self, pcm, frame_size): @@ -114,12 +143,92 @@ class OpusEncoder(BaseOpus): pcm = ctypes.cast(pcm, c_int16_ptr) data = (ctypes.c_char * max_data_bytes)() - ret = self.encode(self.inst, pcm, frame_size, data, max_data_bytes) + ret = self.opus_encode(self.inst, pcm, frame_size, data, max_data_bytes) if ret < 0: raise Exception('Failed to encode: {}'.format(ret)) - return array.array('b', data[:ret]).tobytes() + # TODO: py3 + return array.array('b', data[:ret]).tostring() class OpusDecoder(BaseOpus): pass + + +class BufferedOpusEncoder(OpusEncoder): + def __init__(self, data, *args, **kwargs): + self.data = StringIO(data) + self.frames = Queue(kwargs.pop('queue_size', 4096)) + super(BufferedOpusEncoder, self).__init__(*args, **kwargs) + gevent.spawn(self._encoder_loop) + + def _encoder_loop(self): + while self.data: + raw = self.data.read(self.frame_size) + if len(raw) < self.frame_size: + break + + self.frames.put(self.encode(raw, self.samples_per_frame)) + gevent.idle() + self.data = None + + def have_frame(self): + return self.data or not self.frames.empty() + + def next_frame(self): + return self.frames.get() + + +class GIPCBufferedOpusEncoder(OpusEncoder): + FIN = 1 + + def __init__(self, data, *args, **kwargs): + import gipc + + self.data = StringIO(data) + self.parent_pipe, self.child_pipe = gipc.pipe(duplex=True) + self.frames = Queue(kwargs.pop('queue_size', 4096)) + super(GIPCBufferedOpusEncoder, self).__init__(*args, **kwargs) + + gipc.start_process(target=self._encoder_loop, args=(self.child_pipe, (args, kwargs))) + + gevent.spawn(self._writer) + gevent.spawn(self._reader) + + def _reader(self): + while True: + data = self.parent_pipe.get() + if data == self.FIN: + return + + self.frames.put(data) + self.parent_pipe = None + + def _writer(self): + while self.data: + raw = self.data.read(self.frame_size) + if len(raw) < self.frame_size: + break + + self.parent_pipe.put(raw) + gevent.idle() + + self.parent_pipe.put(self.FIN) + + def have_frame(self): + return self.parent_pipe + + def next_frame(self): + return self.frames.get() + + @classmethod + def _encoder_loop(cls, pipe, (args, kwargs)): + encoder = OpusEncoder(*args, **kwargs) + + while True: + data = pipe.get() + if data == cls.FIN: + pipe.put(cls.FIN) + return + + pipe.put(encoder.encode(data, encoder.samples_per_frame)) diff --git a/disco/voice/player.py b/disco/voice/player.py index 144a25b..353c0cf 100644 --- a/disco/voice/player.py +++ b/disco/voice/player.py @@ -1,18 +1,54 @@ +import time import gevent import struct -import time +import subprocess from six.moves import queue from disco.voice.client import VoiceState +from disco.voice.opus import BufferedOpusEncoder, GIPCBufferedOpusEncoder + + +class BaseFFmpegPlayable(object): + def __init__(self, source='-', command='avconv', sampling_rate=48000, channels=2, **kwargs): + args = [command, '-i', source, '-f', 's16le', '-ar', str(sampling_rate), '-ac', str(channels), '-loglevel', 'warning', 'pipe:1'] + self.proc = subprocess.Popen(args, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + data, _ = self.proc.communicate() + super(BaseFFmpegPlayable, self).__init__(data, sampling_rate, channels, **kwargs) + + +class FFmpegPlayable(BaseFFmpegPlayable, BufferedOpusEncoder): + pass + + +class GIPCFFmpegPlayable(BaseFFmpegPlayable, GIPCBufferedOpusEncoder): + pass -class OpusItem(object): - def __init__(self, frame_length=20, channels=2): +def create_youtube_dl_playable(url, cls=FFmpegPlayable, *args, **kwargs): + import youtube_dl + + ydl = youtube_dl.YoutubeDL({'format': 'webm[abr>0]/bestaudio/best'}) + info = ydl.extract_info(url, download=False) + + if 'entries' in info: + info = info['entries'][0] + + return cls(info['url'], *args, **kwargs), info + + +class OpusPlayable(object): + """ + Represents a Playable item which is a cached set of Opus-encoded bytes. + """ + def __init__(self, sampling_rate=48000, frame_length=20, channels=2): self.frames = [] self.idx = 0 + self.frame_length = 20 + self.sampling_rate = sampling_rate self.frame_length = frame_length self.channels = channels + self.sample_size = int(self.sampling_rate / 1000 * self.frame_length) @classmethod def from_raw_file(cls, path): @@ -58,6 +94,7 @@ class Player(object): def play(self, item): start = time.time() loops = 0 + timestamp = 0 while True: loops += 1 @@ -76,13 +113,15 @@ class Player(object): if not item.have_frame(): return - self.client.send_frame(item.next_frame()) + self.client.send_frame(item.next_frame(), loops, timestamp) + timestamp += item.samples_per_frame next_time = start + 0.02 * loops delay = max(0, 0.02 + (next_time - time.time())) gevent.sleep(delay) def run(self): self.client.set_speaking(True) + while self.playing: self.play(self.queue.get()) @@ -90,4 +129,6 @@ class Player(object): self.playing = False self.complete.set() return + self.client.set_speaking(False) + self.disconnect() diff --git a/examples/music.py b/examples/music.py index ce0d4a9..0fbb4f2 100644 --- a/examples/music.py +++ b/examples/music.py @@ -1,15 +1,16 @@ from disco.bot import Plugin from disco.bot.command import CommandError -from disco.voice.client import Player, OpusItem, VoiceException +from disco.voice.player import Player, create_youtube_dl_playable +from disco.voice.client import VoiceException def download(url): - return OpusItem.from_raw_file('test.dca') + return create_youtube_dl_playable(url)[0] class MusicPlugin(Plugin): - def load(self): - super(MusicPlugin, self).load() + def load(self, ctx): + super(MusicPlugin, self).load(ctx) self.guilds = {} @Plugin.command('join') diff --git a/requirements.txt b/requirements.txt index 24c51ff..3b0894b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ gevent==1.1.2 -holster==1.0.7 +holster==1.0.11 inflection==0.3.1 requests==2.11.1 six==1.10.0