diff --git a/disco/bot/__init__.py b/disco/bot/__init__.py index 135ac2b..060d02b 100644 --- a/disco/bot/__init__.py +++ b/disco/bot/__init__.py @@ -1,5 +1,6 @@ from disco.bot.bot import Bot, BotConfig from disco.bot.plugin import Plugin +from disco.bot.command import CommandLevels from disco.util.config import Config -__all__ = ['Bot', 'BotConfig', 'Plugin', 'Config'] +__all__ = ['Bot', 'BotConfig', 'Plugin', 'Config', 'CommandLevels'] diff --git a/disco/bot/backends/memory.py b/disco/bot/backends/memory.py index 5db93b9..8e26ca2 100644 --- a/disco/bot/backends/memory.py +++ b/disco/bot/backends/memory.py @@ -2,17 +2,6 @@ from .base import BaseStorageBackend, StorageDict class MemoryBackend(BaseStorageBackend): - def __init__(self): + def __init__(self, config): self.storage = StorageDict() - def base(self): - return self.storage - - def __getitem__(self, key): - return self.storage[key] - - def __setitem__(self, key, value): - self.storage[key] = value - - def __delitem__(self, key): - del self.storage[key] diff --git a/disco/bot/bot.py b/disco/bot/bot.py index ef49fb5..4c6846f 100644 --- a/disco/bot/bot.py +++ b/disco/bot/bot.py @@ -6,8 +6,9 @@ import inspect from six.moves import reload_module from holster.threadlocal import ThreadLocal +from disco.types.guild import GuildMember from disco.bot.plugin import Plugin -from disco.bot.command import CommandEvent +from disco.bot.command import CommandEvent, CommandLevels from disco.bot.storage import Storage from disco.util.config import Config from disco.util.serializer import Serializer @@ -20,9 +21,11 @@ class BotConfig(Config): Attributes ---------- - token : str - The authentication token for this bot. This is passed on to the - :class:`disco.client.Client` without any validation. + levels : dict(snowflake, str) + Mapping of user IDs/role IDs to :class:`disco.bot.commands.CommandLevesls` + which is used for the default commands_level_getter. + plugins : list[string] + List of plugin modules to load. commands_enabled : bool Whether this bot instance should utilize command parsing. Generally this should be true, unless your bot is only handling events and has no user @@ -42,17 +45,21 @@ class BotConfig(Config): If true, the bot will reparse an edited message if it was the last sent message in a channel, and did not previously trigger a command. This is helpful for allowing edits to typod commands. + commands_level_getter : function + If set, a function which when given a GuildMember or User, returns the + relevant :class:`disco.bot.commands.CommandLevels`. plugin_config_provider : Optional[function] If set, this function will replace the default configuration loading function, which normally attempts to load a file located at config/plugin_name.fmt where fmt is the plugin_config_format. The function here should return a valid configuration object which the plugin understands. plugin_config_format : str - The serilization format plugin configuration files are in. + The serialization format plugin configuration files are in. plugin_config_dir : str The directory plugin configuration is located within. """ - token = None + levels = {} + plugins = {} commands_enabled = True commands_require_mention = True @@ -64,6 +71,7 @@ class BotConfig(Config): } commands_prefix = '' commands_allow_edit = True + commands_level_getter = None plugin_config_provider = None plugin_config_format = 'yaml' @@ -127,6 +135,14 @@ class Bot(object): # Stores a giant regex matcher for all commands self.command_matches_re = None + # Finally, load all the plugin modules that where passed with the config + for plugin_mod in self.config.plugins: + self.add_plugin_module(plugin_mod) + + # Convert level mapping + for k, v in self.config.levels.items(): + self.config.levels[k] = CommandLevels.get(v) + @classmethod def from_cli(cls, *plugins): """ @@ -225,6 +241,32 @@ class Bot(object): if match: yield (command, match) + def get_level(self, actor): + level = CommandLevels.DEFAULT + + if callable(self.config.commands_level_getter): + level = self.config.commands_level_getter(actor) + else: + if actor.id in self.config.levels: + level = self.config.levels[actor.id] + + if isinstance(actor, GuildMember): + for rid in actor.roles: + if rid in self.config.levels and self.config.levels[rid] > level: + level = self.config.levels[rid] + + return level + + def check_command_permissions(self, command, msg): + if not command.level: + return True + + level = self.get_level(msg.author if not msg.guild else msg.guild.get_member(msg.author)) + + if level >= command.level: + return True + return False + def handle_message(self, msg): """ Attempts to handle a newly created or edited message in the context of @@ -243,10 +285,14 @@ class Bot(object): commands = list(self.get_commands_for_message(msg)) if len(commands): - return any([ - command.plugin.execute(CommandEvent(command, msg, match)) - for command, match in commands - ]) + result = False + for command, match in commands: + if not self.check_command_permissions(command, msg): + continue + + if command.plugin.execute(CommandEvent(command, msg, match)): + result = True + return result return False diff --git a/disco/bot/command.py b/disco/bot/command.py index 4fbd9e1..53dd672 100644 --- a/disco/bot/command.py +++ b/disco/bot/command.py @@ -1,5 +1,7 @@ import re +from holster.enum import Enum + from disco.bot.parser import ArgumentSet, ArgumentError from disco.util.functional import cached_property @@ -7,6 +9,14 @@ REGEX_FMT = '({})' ARGS_REGEX = '( (.*)$|$)' MENTION_RE = re.compile('<@!?([0-9]+)>') +CommandLevels = Enum( + DEFAULT=0, + TRUSTED=10, + MOD=50, + ADMIN=100, + OWNER=500, +) + class CommandEvent(object): """ @@ -33,7 +43,7 @@ class CommandEvent(object): self.msg = msg self.match = match self.name = self.match.group(1) - self.args = self.match.group(2).strip().split(' ') + self.args = [i for i in self.match.group(2).strip().split(' ') if i] @cached_property def member(self): @@ -93,7 +103,9 @@ class Command(object): is_regex : Optional[bool] Whether the triggers for this command should be treated as raw regex. """ - def __init__(self, plugin, func, trigger, args=None, aliases=None, group=None, is_regex=False): + def __init__(self, plugin, func, trigger, args=None, level=None, + aliases=None, group=None, is_regex=False): + self.plugin = plugin self.func = func self.triggers = [trigger] + (aliases or []) @@ -110,6 +122,7 @@ class Command(object): 'role': self.mention_type([resolve_role], force=True), }) + self.level = level self.group = group self.is_regex = is_regex diff --git a/disco/bot/plugin.py b/disco/bot/plugin.py index 87f0e0b..dd6c097 100644 --- a/disco/bot/plugin.py +++ b/disco/bot/plugin.py @@ -44,10 +44,23 @@ class PluginDeco(object): """ return cls.add_meta_deco({ 'type': 'listener', - 'event_name': event_name, + 'what': 'event', + 'desc': event_name, 'priority': priority }) + @classmethod + def listen_packet(cls, op, priority=None): + """ + Binds the function to listen for a given gateway op code + """ + return cls.add_meta_deco({ + 'type': 'listener', + 'what': 'packet', + 'desc': op, + 'priority': priority, + }) + @classmethod def command(cls, *args, **kwargs): """ @@ -155,7 +168,7 @@ class Plugin(LoggingClass, PluginDeco): if hasattr(member, 'meta'): for meta in member.meta: if meta['type'] == 'listener': - self.register_listener(member, meta['event_name'], meta['priority']) + self.register_listener(member, meta['what'], meta['desc'], meta['priority']) elif meta['type'] == 'command': self.register_command(member, *meta['args'], **meta['kwargs']) elif meta['type'] == 'schedule': @@ -205,21 +218,33 @@ class Plugin(LoggingClass, PluginDeco): return True - def register_listener(self, func, name, priority): + def register_listener(self, func, what, desc, priority): """ Registers a listener Parameters ---------- + what : str + What the listener is for (event, packet) func : function The function to be registered. - name : string - Name of event to listen for. + desc + The descriptor of the event/packet. priority : Priority The priority of this listener. """ func = functools.partial(self._dispatch, 'listener', func) - self.listeners.append(self.bot.client.events.on(name, func, priority=priority or Priority.NONE)) + + priority = priority or Priority.NONE + + if what == 'event': + li = self.bot.client.events.on(desc, func, priority=priority) + elif what == 'packet': + li = self.bot.client.packets.on(desc, func, priority=priority) + else: + raise Exception('Invalid listener what: {}'.format(what)) + + self.listeners.append(li) def register_command(self, func, *args, **kwargs): """ diff --git a/disco/cli.py b/disco/cli.py index 0c11d52..bb2e357 100644 --- a/disco/cli.py +++ b/disco/cli.py @@ -4,6 +4,7 @@ creating and running bots/clients. """ from __future__ import print_function +import os import logging import argparse @@ -12,13 +13,14 @@ from gevent import monkey monkey.patch_all() parser = argparse.ArgumentParser() -parser.add_argument('--token', help='Bot Authentication Token', required=True) -parser.add_argument('--shard-count', help='Total number of shards', default=1) -parser.add_argument('--shard-id', help='Current shard number/id', default=0) -parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=False) -parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default='localhost:8484') -parser.add_argument('--encoder', help='encoder for gateway data', default='json') -parser.add_argument('--bot', help='run a disco bot on this client', action='store_true', default=False) +parser.add_argument('--config', help='Configuration file', default='config.yaml') +parser.add_argument('--token', help='Bot Authentication Token', default=None) +parser.add_argument('--shard-count', help='Total number of shards', default=None) +parser.add_argument('--shard-id', help='Current shard number/id', default=None) +parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=None) +parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default=None) +parser.add_argument('--encoder', help='encoder for gateway data', default=None) +parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False) parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[]) logging.basicConfig(level=logging.INFO) @@ -37,34 +39,34 @@ def disco_main(run=False): args = parser.parse_args() from disco.client import Client, ClientConfig - from disco.bot import Bot - from disco.gateway.encoding import ENCODERS + from disco.bot import Bot, BotConfig from disco.util.token import is_valid_token - if not is_valid_token(args.token): - print('Invalid token passed') - return + if os.path.exists(args.config): + config = ClientConfig.from_file(args.config) + else: + config = ClientConfig() - cfg = ClientConfig() - cfg.token = args.token - cfg.shard_id = args.shard_id - cfg.shard_count = args.shard_count - cfg.manhole_enable = args.manhole - cfg.manhole_bind = args.manhole_bind - cfg.encoding_cls = ENCODERS[args.encoder] + for k, v in vars(args).items(): + if hasattr(config, k) and v is not None: + setattr(config, k, v) - client = Client(cfg) + if not is_valid_token(config.token): + print('Invalid token passed') + return - if args.bot: - bot = Bot(client) + client = Client(config) - for plugin in args.plugin: - bot.add_plugin_module(plugin) + bot = None + if args.run_bot or hasattr(config, 'bot'): + bot_config = BotConfig(config.bot) if hasattr(config, 'bot') else BotConfig() + bot_config.plugins += args.plugin + bot = Bot(client, bot_config) if run: - client.run_forever() + (bot or client).run_forever() - return client + return (bot or client) if __name__ == '__main__': disco_main(True) diff --git a/disco/client.py b/disco/client.py index cfadc80..4d3bf54 100644 --- a/disco/client.py +++ b/disco/client.py @@ -5,11 +5,12 @@ from holster.emitter import Emitter from disco.state import State from disco.api.client import APIClient from disco.gateway.client import GatewayClient +from disco.util.config import Config from disco.util.logging import LoggingClass from disco.util.backdoor import DiscoBackdoorServer -class ClientConfig(LoggingClass): +class ClientConfig(LoggingClass, Config): """ Configuration for the :class:`Client`. @@ -27,8 +28,9 @@ class ClientConfig(LoggingClass): manhole_bind : tuple(str, int) A (host, port) combination which the manhole server will bind to (if its enabled using :attr:`manhole_enable`). - encoding_cls : class - The class to use for encoding/decoding data from websockets. + encoder : str + The type of encoding to use for encoding/decoding data from websockets, + should be either 'json' or 'etf'. """ token = "" @@ -38,7 +40,7 @@ class ClientConfig(LoggingClass): manhole_enable = True manhole_bind = ('127.0.0.1', 8484) - encoding_cls = None + encoder = 'json' class Client(object): @@ -82,7 +84,7 @@ class Client(object): self.state = State(self) self.api = APIClient(self) - self.gw = GatewayClient(self, self.config.encoding_cls) + self.gw = GatewayClient(self, self.config.encoder) if self.config.manhole_enable: self.manhole_locals = { diff --git a/disco/gateway/client.py b/disco/gateway/client.py index 87194c4..6c88d9f 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -3,9 +3,9 @@ import zlib import six import ssl -from disco.gateway.packets import OPCode +from disco.gateway.packets import OPCode, RECV, SEND from disco.gateway.events import GatewayEvent -from disco.gateway.encoding.json import JSONEncoder +from disco.gateway.encoding import ENCODERS from disco.util.websocket import Websocket from disco.util.logging import LoggingClass @@ -16,20 +16,20 @@ class GatewayClient(LoggingClass): GATEWAY_VERSION = 6 MAX_RECONNECTS = 5 - def __init__(self, client, encoder=None): + def __init__(self, client, encoder='json'): super(GatewayClient, self).__init__() self.client = client - self.encoder = encoder or JSONEncoder + self.encoder = ENCODERS[encoder] self.events = client.events self.packets = client.packets # Create emitter and bind to gateway payloads - self.packets.on(OPCode.DISPATCH, self.handle_dispatch) - self.packets.on(OPCode.HEARTBEAT, self.handle_heartbeat) - self.packets.on(OPCode.RECONNECT, self.handle_reconnect) - self.packets.on(OPCode.INVALID_SESSION, self.handle_invalid_session) - self.packets.on(OPCode.HELLO, self.handle_hello) + self.packets.on((RECV, OPCode.DISPATCH), self.handle_dispatch) + self.packets.on((RECV, OPCode.HEARTBEAT), self.handle_heartbeat) + self.packets.on((RECV, OPCode.RECONNECT), self.handle_reconnect) + self.packets.on((RECV, OPCode.INVALID_SESSION), self.handle_invalid_session) + self.packets.on((RECV, OPCode.HELLO), self.handle_hello) # Bind to ready payload self.events.on('Ready', self.on_ready) @@ -50,6 +50,7 @@ class GatewayClient(LoggingClass): self._heartbeat_task = None def send(self, op, data): + self.packets.emit((SEND, op), data) self.ws.send(self.encoder.encode({ 'op': op.value, 'd': data, @@ -119,7 +120,7 @@ class GatewayClient(LoggingClass): self.seq = data['s'] # Emit packet - self.packets.emit(OPCode[data['op']], data) + self.packets.emit((RECV, OPCode[data['op']]), data) def on_error(self, error): if isinstance(error, KeyboardInterrupt): diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 3da9173..99f6d9e 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -1,3 +1,5 @@ +from __future__ import print_function + import inflection import six @@ -49,6 +51,23 @@ class GatewayEvent(Model): raise AttributeError(name) +def debug(func=None): + def deco(cls): + old_init = cls.__init__ + + def new_init(self, obj, *args, **kwargs): + if func: + print(func(obj)) + else: + print(obj) + + old_init(self, obj, *args, **kwargs) + + cls.__init__ = new_init + return cls + return deco + + def wraps_model(model, alias=None): alias = alias or model.__name__.lower() diff --git a/disco/gateway/packets.py b/disco/gateway/packets.py index 5ca9793..e78c1ce 100644 --- a/disco/gateway/packets.py +++ b/disco/gateway/packets.py @@ -1,5 +1,8 @@ from holster.enum import Enum +SEND = object() +RECV = object() + OPCode = Enum( DISPATCH=0, HEARTBEAT=1, diff --git a/disco/state.py b/disco/state.py index 08ed556..0b30903 100644 --- a/disco/state.py +++ b/disco/state.py @@ -79,7 +79,7 @@ class State(object): EVENTS = [ 'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove', 'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete', - 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate' + 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate', ] def __init__(self, client, config=None): @@ -96,7 +96,7 @@ class State(object): # If message tracking is enabled, listen to those events if self.config.track_messages: self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) - self.EVENTS += ['MessageCreate', 'MessageDelete'] + self.EVENTS += ['MessageDelete'] # The bound listener objects self.listeners = [] @@ -120,25 +120,21 @@ class State(object): func = 'on_' + inflection.underscore(event) self.listeners.append(self.client.events.on(event, getattr(self, func))) + def fill_messages(self, channel): + for message in reversed(next(channel.messages_iter(bulk=True))): + self.messages[channel.id].append( + StackMessage(message.id, message.channel_id, message.author.id)) + def on_ready(self, event): self.me = event.user def on_message_create(self, event): - self.messages[event.message.channel_id].append( - StackMessage(event.message.id, event.message.channel_id, event.message.author.id)) - - def on_message_update(self, event): - message, cid = event.message, event.message.channel_id - if cid not in self.messages: - return - - sm = next((i for i in self.messages[cid] if i.id == message.id), None) - if not sm: - return + if self.config.track_messages: + self.messages[event.message.channel_id].append( + StackMessage(event.message.id, event.message.channel_id, event.message.author.id)) - sm.id = message.id - sm.channel_id = cid - sm.author_id = message.author.id + if event.message.channel_id in self.channels: + self.channels[event.message.channel_id].last_message_id = event.message.id def on_message_delete(self, event): if event.channel_id not in self.messages: @@ -150,6 +146,14 @@ class State(object): self.messages[event.channel_id].remove(sm) + def on_message_delete_bulk(self, event): + if event.channel_id not in self.messages: + return + + for sm in self.messages[event.channel_id]: + if sm.id in event.ids: + self.messages[event.channel_id].remove(sm) + def on_guild_create(self, event): self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) diff --git a/disco/types/base.py b/disco/types/base.py index cb95f79..f91ea07 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -124,8 +124,18 @@ def binary(obj): return bytes(obj) -def field(typ, alias=None): - pass +def with_equality(field): + class T(object): + def __eq__(self, other): + return getattr(self, field) == getattr(other, field) + return T + + +def with_hash(field): + class T(object): + def __hash__(self, other): + return hash(getattr(self, field)) + return T class ModelMeta(type): @@ -156,7 +166,7 @@ class Model(six.with_metaclass(ModelMeta)): obj = kwargs for name, field in self._fields.items(): - if name not in obj or not obj[field.src_name]: + if field.src_name not in obj or not obj[field.src_name]: if field.has_default(): setattr(self, field.dst_name, field.default()) continue diff --git a/disco/types/channel.py b/disco/types/channel.py index 92d3b6a..a54b4bf 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -4,7 +4,7 @@ from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text from disco.types.permissions import PermissionValue from disco.util import to_snowflake -from disco.util.functional import cached_property +from disco.util.functional import cached_property, one_or_many, chunks from disco.types.user import User from disco.types.permissions import Permissions, Permissible from disco.voice.client import VoiceClient @@ -81,7 +81,7 @@ class Channel(Model, Permissible): guild_id = Field(snowflake) name = Field(text) topic = Field(text) - _last_message_id = Field(snowflake, alias='last_message_id') + last_message_id = Field(snowflake) position = Field(int) bitrate = Field(int) recipients = Field(listof(User)) @@ -138,15 +138,6 @@ class Channel(Model, Permissible): """ return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM) - @property - def last_message_id(self): - """ - Returns the ID of the last message sent in this channel - """ - if self.id not in self.client.state.messages: - return self._last_message_id - return self.client.state.messages[self.id][-1].id - @property def messages(self): """ @@ -159,7 +150,7 @@ class Channel(Model, Permissible): Creates a new :class:`MessageIterator` for the channel with the given keyword arguments """ - return MessageIterator(self.client, self.id, before=self.last_message_id, **kwargs) + return MessageIterator(self.client, self.id, **kwargs) @cached_property def guild(self): @@ -242,9 +233,40 @@ class Channel(Model, Permissible): def delete_overwrite(self, ow): self.client.api.channels_permissions_delete(self.id, ow.id) - def delete_messages_bulk(self, messages): + def delete_message(self, message): + """ + Deletes a single message from this channel. + + Args + ---- + message : snowflake|:class:`disco.types.message.Message` + The message to delete. + """ + self.client.api.channels_messages_delete(self.id, to_snowflake(message)) + + @one_or_many + def delete_messages(self, messages): + """ + Deletes a set of messages using the correct API route based on the number + of messages passed. + + Args + ---- + messages : list[snowflake|:class:`disco.types.message.Message`] + List of messages (or message ids) to delete. All messages must originate + from this channel. + """ messages = map(to_snowflake, messages) - self.client.api.channels_messages_delete_bulk(self.id, messages) + + if not messages: + return + + if len(messages) <= 2: + for msg in messages: + self.delete_message(msg) + else: + for chunk in chunks(messages, 100): + self.client.api.channels_messages_delete_bulk(self.id, chunk) class MessageIterator(object): @@ -283,9 +305,6 @@ class MessageIterator(object): self.last = None self._buffer = [] - if not before and not after: - raise Exception('Must specify at most one of before or after') - if not any((before, after)) and self.direction == self.Direction.DOWN: raise Exception('Must specify either before or after for downward seeking') diff --git a/disco/types/user.py b/disco/types/user.py index ccd2e5b..a5952bc 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -1,7 +1,7 @@ -from disco.types.base import Model, Field, snowflake, text, binary +from disco.types.base import Model, Field, snowflake, text, binary, with_equality, with_hash -class User(Model): +class User(Model, with_equality('id'), with_hash('id')): id = Field(snowflake) username = Field(text) discriminator = Field(str) diff --git a/disco/util/config.py b/disco/util/config.py index 74927c8..de98947 100644 --- a/disco/util/config.py +++ b/disco/util/config.py @@ -21,8 +21,8 @@ class Config(object): data = f.read() _, ext = os.path.splitext(path) - Serializer.check_format(ext) - inst.__dict__.update(Serializer.load(ext, data)) + Serializer.check_format(ext[1:]) + inst.__dict__.update(Serializer.loads(ext[1:], data)) return inst def from_prefix(self, prefix): @@ -33,10 +33,13 @@ class Config(object): if k.startswith(prefix): obj[k[len(prefix):]] = v - return obj + return Config(obj) def update(self, other): if isinstance(other, Config): other = other.__dict__ self.__dict__.update(other) + + def to_dict(self): + return self.__dict__ diff --git a/disco/util/functional.py b/disco/util/functional.py index 951a1e4..8641457 100644 --- a/disco/util/functional.py +++ b/disco/util/functional.py @@ -1,5 +1,54 @@ from gevent.lock import RLock +from six.moves import range + +NO_MORE_SENTINEL = object() + + +def take(seq, count): + """ + Take count many elements from a sequence or generator. + + Args + ---- + seq : sequnce or generator + The sequnce to take elements from. + count : int + The number of elments to take. + """ + for _ in range(count): + i = next(seq, NO_MORE_SENTINEL) + if i is NO_MORE_SENTINEL: + raise StopIteration + yield i + + +def chunks(obj, size): + """ + Splits a list into sized chunks. + + Args + ---- + obj : list + List to split up. + size : int + Size of chunks to split list into. + """ + for i in range(0, len(obj), size): + yield obj[i:i + size] + + +def one_or_many(f): + """ + Wraps a function so that it will either take a single argument, or a variable + number of args. + """ + def _f(*args): + if len(args) == 1: + return f(args[0]) + return f(*args) + return _f + def cached_property(f): """ diff --git a/disco/util/snowflake.py b/disco/util/snowflake.py new file mode 100644 index 0000000..bbca39d --- /dev/null +++ b/disco/util/snowflake.py @@ -0,0 +1,18 @@ +from datetime import datetime + +DISCORD_EPOCH = 1420070400000 + + +def to_datetime(snowflake): + """ + Converts a snowflake to a UTC datetime. + """ + return datetime.utcfromtimestamp(to_unix(snowflake)) + + +def to_unix(snowflake): + return to_unix_ms(snowflake) / 1000 + + +def to_unix_ms(snowflake): + return ((int(snowflake) >> 22) + DISCORD_EPOCH) diff --git a/requirements.txt b/requirements.txt index a8586e7..9051886 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ gevent==1.1.2 -holster==1.0.5 +holster==1.0.6 inflection==0.3.1 requests==2.11.1 six==1.10.0