Browse Source

no seriously make small commits

- Storage backends take a config
- Add command permissions
- Add ability to listen to BOTH incoming and outgoing gateway packets
- Heavily refactor cli, now prefer loading from config with options as
overrides
- Add debug function to gateway events, helps with figuring data out w/o
spamming console
- Change Channel.last_message_id from property that looks in the message
tracking deque, to a attribute that gets updated by the state module
- Add State.fill_messages for backfilling the messages store
- Handle MessageDeleteBulk in State
- Add some helper functions for hash/equality model functions
- Fix MessageIterator
- Add Channel.delete_message, Channel.delete_messages
- Some more functional stuff
- Snowflake timestamp conversion
- Bump holster
pull/5/head
Andrei 9 years ago
parent
commit
6036dd8150
  1. 3
      disco/bot/__init__.py
  2. 13
      disco/bot/backends/memory.py
  3. 66
      disco/bot/bot.py
  4. 17
      disco/bot/command.py
  5. 37
      disco/bot/plugin.py
  6. 54
      disco/cli.py
  7. 12
      disco/client.py
  8. 21
      disco/gateway/client.py
  9. 19
      disco/gateway/events.py
  10. 3
      disco/gateway/packets.py
  11. 32
      disco/state.py
  12. 16
      disco/types/base.py
  13. 53
      disco/types/channel.py
  14. 4
      disco/types/user.py
  15. 9
      disco/util/config.py
  16. 49
      disco/util/functional.py
  17. 18
      disco/util/snowflake.py
  18. 2
      requirements.txt

3
disco/bot/__init__.py

@ -1,5 +1,6 @@
from disco.bot.bot import Bot, BotConfig from disco.bot.bot import Bot, BotConfig
from disco.bot.plugin import Plugin from disco.bot.plugin import Plugin
from disco.bot.command import CommandLevels
from disco.util.config import Config from disco.util.config import Config
__all__ = ['Bot', 'BotConfig', 'Plugin', 'Config'] __all__ = ['Bot', 'BotConfig', 'Plugin', 'Config', 'CommandLevels']

13
disco/bot/backends/memory.py

@ -2,17 +2,6 @@ from .base import BaseStorageBackend, StorageDict
class MemoryBackend(BaseStorageBackend): class MemoryBackend(BaseStorageBackend):
def __init__(self): def __init__(self, config):
self.storage = StorageDict() self.storage = StorageDict()
def base(self):
return self.storage
def __getitem__(self, key):
return self.storage[key]
def __setitem__(self, key, value):
self.storage[key] = value
def __delitem__(self, key):
del self.storage[key]

66
disco/bot/bot.py

@ -6,8 +6,9 @@ import inspect
from six.moves import reload_module from six.moves import reload_module
from holster.threadlocal import ThreadLocal from holster.threadlocal import ThreadLocal
from disco.types.guild import GuildMember
from disco.bot.plugin import Plugin from disco.bot.plugin import Plugin
from disco.bot.command import CommandEvent from disco.bot.command import CommandEvent, CommandLevels
from disco.bot.storage import Storage from disco.bot.storage import Storage
from disco.util.config import Config from disco.util.config import Config
from disco.util.serializer import Serializer from disco.util.serializer import Serializer
@ -20,9 +21,11 @@ class BotConfig(Config):
Attributes Attributes
---------- ----------
token : str levels : dict(snowflake, str)
The authentication token for this bot. This is passed on to the Mapping of user IDs/role IDs to :class:`disco.bot.commands.CommandLevesls`
:class:`disco.client.Client` without any validation. which is used for the default commands_level_getter.
plugins : list[string]
List of plugin modules to load.
commands_enabled : bool commands_enabled : bool
Whether this bot instance should utilize command parsing. Generally this Whether this bot instance should utilize command parsing. Generally this
should be true, unless your bot is only handling events and has no user should be true, unless your bot is only handling events and has no user
@ -42,17 +45,21 @@ class BotConfig(Config):
If true, the bot will reparse an edited message if it was the last sent If true, the bot will reparse an edited message if it was the last sent
message in a channel, and did not previously trigger a command. This is message in a channel, and did not previously trigger a command. This is
helpful for allowing edits to typod commands. helpful for allowing edits to typod commands.
commands_level_getter : function
If set, a function which when given a GuildMember or User, returns the
relevant :class:`disco.bot.commands.CommandLevels`.
plugin_config_provider : Optional[function] plugin_config_provider : Optional[function]
If set, this function will replace the default configuration loading If set, this function will replace the default configuration loading
function, which normally attempts to load a file located at config/plugin_name.fmt function, which normally attempts to load a file located at config/plugin_name.fmt
where fmt is the plugin_config_format. The function here should return where fmt is the plugin_config_format. The function here should return
a valid configuration object which the plugin understands. a valid configuration object which the plugin understands.
plugin_config_format : str plugin_config_format : str
The serilization format plugin configuration files are in. The serialization format plugin configuration files are in.
plugin_config_dir : str plugin_config_dir : str
The directory plugin configuration is located within. The directory plugin configuration is located within.
""" """
token = None levels = {}
plugins = {}
commands_enabled = True commands_enabled = True
commands_require_mention = True commands_require_mention = True
@ -64,6 +71,7 @@ class BotConfig(Config):
} }
commands_prefix = '' commands_prefix = ''
commands_allow_edit = True commands_allow_edit = True
commands_level_getter = None
plugin_config_provider = None plugin_config_provider = None
plugin_config_format = 'yaml' plugin_config_format = 'yaml'
@ -127,6 +135,14 @@ class Bot(object):
# Stores a giant regex matcher for all commands # Stores a giant regex matcher for all commands
self.command_matches_re = None self.command_matches_re = None
# Finally, load all the plugin modules that where passed with the config
for plugin_mod in self.config.plugins:
self.add_plugin_module(plugin_mod)
# Convert level mapping
for k, v in self.config.levels.items():
self.config.levels[k] = CommandLevels.get(v)
@classmethod @classmethod
def from_cli(cls, *plugins): def from_cli(cls, *plugins):
""" """
@ -225,6 +241,32 @@ class Bot(object):
if match: if match:
yield (command, match) yield (command, match)
def get_level(self, actor):
level = CommandLevels.DEFAULT
if callable(self.config.commands_level_getter):
level = self.config.commands_level_getter(actor)
else:
if actor.id in self.config.levels:
level = self.config.levels[actor.id]
if isinstance(actor, GuildMember):
for rid in actor.roles:
if rid in self.config.levels and self.config.levels[rid] > level:
level = self.config.levels[rid]
return level
def check_command_permissions(self, command, msg):
if not command.level:
return True
level = self.get_level(msg.author if not msg.guild else msg.guild.get_member(msg.author))
if level >= command.level:
return True
return False
def handle_message(self, msg): def handle_message(self, msg):
""" """
Attempts to handle a newly created or edited message in the context of Attempts to handle a newly created or edited message in the context of
@ -243,10 +285,14 @@ class Bot(object):
commands = list(self.get_commands_for_message(msg)) commands = list(self.get_commands_for_message(msg))
if len(commands): if len(commands):
return any([ result = False
command.plugin.execute(CommandEvent(command, msg, match)) for command, match in commands:
for command, match in commands if not self.check_command_permissions(command, msg):
]) continue
if command.plugin.execute(CommandEvent(command, msg, match)):
result = True
return result
return False return False

17
disco/bot/command.py

@ -1,5 +1,7 @@
import re import re
from holster.enum import Enum
from disco.bot.parser import ArgumentSet, ArgumentError from disco.bot.parser import ArgumentSet, ArgumentError
from disco.util.functional import cached_property from disco.util.functional import cached_property
@ -7,6 +9,14 @@ REGEX_FMT = '({})'
ARGS_REGEX = '( (.*)$|$)' ARGS_REGEX = '( (.*)$|$)'
MENTION_RE = re.compile('<@!?([0-9]+)>') MENTION_RE = re.compile('<@!?([0-9]+)>')
CommandLevels = Enum(
DEFAULT=0,
TRUSTED=10,
MOD=50,
ADMIN=100,
OWNER=500,
)
class CommandEvent(object): class CommandEvent(object):
""" """
@ -33,7 +43,7 @@ class CommandEvent(object):
self.msg = msg self.msg = msg
self.match = match self.match = match
self.name = self.match.group(1) self.name = self.match.group(1)
self.args = self.match.group(2).strip().split(' ') self.args = [i for i in self.match.group(2).strip().split(' ') if i]
@cached_property @cached_property
def member(self): def member(self):
@ -93,7 +103,9 @@ class Command(object):
is_regex : Optional[bool] is_regex : Optional[bool]
Whether the triggers for this command should be treated as raw regex. Whether the triggers for this command should be treated as raw regex.
""" """
def __init__(self, plugin, func, trigger, args=None, aliases=None, group=None, is_regex=False): def __init__(self, plugin, func, trigger, args=None, level=None,
aliases=None, group=None, is_regex=False):
self.plugin = plugin self.plugin = plugin
self.func = func self.func = func
self.triggers = [trigger] + (aliases or []) self.triggers = [trigger] + (aliases or [])
@ -110,6 +122,7 @@ class Command(object):
'role': self.mention_type([resolve_role], force=True), 'role': self.mention_type([resolve_role], force=True),
}) })
self.level = level
self.group = group self.group = group
self.is_regex = is_regex self.is_regex = is_regex

37
disco/bot/plugin.py

@ -44,10 +44,23 @@ class PluginDeco(object):
""" """
return cls.add_meta_deco({ return cls.add_meta_deco({
'type': 'listener', 'type': 'listener',
'event_name': event_name, 'what': 'event',
'desc': event_name,
'priority': priority 'priority': priority
}) })
@classmethod
def listen_packet(cls, op, priority=None):
"""
Binds the function to listen for a given gateway op code
"""
return cls.add_meta_deco({
'type': 'listener',
'what': 'packet',
'desc': op,
'priority': priority,
})
@classmethod @classmethod
def command(cls, *args, **kwargs): def command(cls, *args, **kwargs):
""" """
@ -155,7 +168,7 @@ class Plugin(LoggingClass, PluginDeco):
if hasattr(member, 'meta'): if hasattr(member, 'meta'):
for meta in member.meta: for meta in member.meta:
if meta['type'] == 'listener': if meta['type'] == 'listener':
self.register_listener(member, meta['event_name'], meta['priority']) self.register_listener(member, meta['what'], meta['desc'], meta['priority'])
elif meta['type'] == 'command': elif meta['type'] == 'command':
self.register_command(member, *meta['args'], **meta['kwargs']) self.register_command(member, *meta['args'], **meta['kwargs'])
elif meta['type'] == 'schedule': elif meta['type'] == 'schedule':
@ -205,21 +218,33 @@ class Plugin(LoggingClass, PluginDeco):
return True return True
def register_listener(self, func, name, priority): def register_listener(self, func, what, desc, priority):
""" """
Registers a listener Registers a listener
Parameters Parameters
---------- ----------
what : str
What the listener is for (event, packet)
func : function func : function
The function to be registered. The function to be registered.
name : string desc
Name of event to listen for. The descriptor of the event/packet.
priority : Priority priority : Priority
The priority of this listener. The priority of this listener.
""" """
func = functools.partial(self._dispatch, 'listener', func) func = functools.partial(self._dispatch, 'listener', func)
self.listeners.append(self.bot.client.events.on(name, func, priority=priority or Priority.NONE))
priority = priority or Priority.NONE
if what == 'event':
li = self.bot.client.events.on(desc, func, priority=priority)
elif what == 'packet':
li = self.bot.client.packets.on(desc, func, priority=priority)
else:
raise Exception('Invalid listener what: {}'.format(what))
self.listeners.append(li)
def register_command(self, func, *args, **kwargs): def register_command(self, func, *args, **kwargs):
""" """

54
disco/cli.py

@ -4,6 +4,7 @@ creating and running bots/clients.
""" """
from __future__ import print_function from __future__ import print_function
import os
import logging import logging
import argparse import argparse
@ -12,13 +13,14 @@ from gevent import monkey
monkey.patch_all() monkey.patch_all()
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--token', help='Bot Authentication Token', required=True) parser.add_argument('--config', help='Configuration file', default='config.yaml')
parser.add_argument('--shard-count', help='Total number of shards', default=1) parser.add_argument('--token', help='Bot Authentication Token', default=None)
parser.add_argument('--shard-id', help='Current shard number/id', default=0) parser.add_argument('--shard-count', help='Total number of shards', default=None)
parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=False) parser.add_argument('--shard-id', help='Current shard number/id', default=None)
parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default='localhost:8484') parser.add_argument('--manhole', action='store_true', help='Enable the manhole', default=None)
parser.add_argument('--encoder', help='encoder for gateway data', default='json') parser.add_argument('--manhole-bind', help='host:port for the manhole to bind too', default=None)
parser.add_argument('--bot', help='run a disco bot on this client', action='store_true', default=False) parser.add_argument('--encoder', help='encoder for gateway data', default=None)
parser.add_argument('--run-bot', help='run a disco bot on this client', action='store_true', default=False)
parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[]) parser.add_argument('--plugin', help='load plugins into the bot', nargs='*', default=[])
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -37,34 +39,34 @@ def disco_main(run=False):
args = parser.parse_args() args = parser.parse_args()
from disco.client import Client, ClientConfig from disco.client import Client, ClientConfig
from disco.bot import Bot from disco.bot import Bot, BotConfig
from disco.gateway.encoding import ENCODERS
from disco.util.token import is_valid_token from disco.util.token import is_valid_token
if not is_valid_token(args.token): if os.path.exists(args.config):
print('Invalid token passed') config = ClientConfig.from_file(args.config)
return else:
config = ClientConfig()
cfg = ClientConfig() for k, v in vars(args).items():
cfg.token = args.token if hasattr(config, k) and v is not None:
cfg.shard_id = args.shard_id setattr(config, k, v)
cfg.shard_count = args.shard_count
cfg.manhole_enable = args.manhole
cfg.manhole_bind = args.manhole_bind
cfg.encoding_cls = ENCODERS[args.encoder]
client = Client(cfg) if not is_valid_token(config.token):
print('Invalid token passed')
return
if args.bot: client = Client(config)
bot = Bot(client)
for plugin in args.plugin: bot = None
bot.add_plugin_module(plugin) if args.run_bot or hasattr(config, 'bot'):
bot_config = BotConfig(config.bot) if hasattr(config, 'bot') else BotConfig()
bot_config.plugins += args.plugin
bot = Bot(client, bot_config)
if run: if run:
client.run_forever() (bot or client).run_forever()
return client return (bot or client)
if __name__ == '__main__': if __name__ == '__main__':
disco_main(True) disco_main(True)

12
disco/client.py

@ -5,11 +5,12 @@ from holster.emitter import Emitter
from disco.state import State from disco.state import State
from disco.api.client import APIClient from disco.api.client import APIClient
from disco.gateway.client import GatewayClient from disco.gateway.client import GatewayClient
from disco.util.config import Config
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
from disco.util.backdoor import DiscoBackdoorServer from disco.util.backdoor import DiscoBackdoorServer
class ClientConfig(LoggingClass): class ClientConfig(LoggingClass, Config):
""" """
Configuration for the :class:`Client`. Configuration for the :class:`Client`.
@ -27,8 +28,9 @@ class ClientConfig(LoggingClass):
manhole_bind : tuple(str, int) manhole_bind : tuple(str, int)
A (host, port) combination which the manhole server will bind to (if its A (host, port) combination which the manhole server will bind to (if its
enabled using :attr:`manhole_enable`). enabled using :attr:`manhole_enable`).
encoding_cls : class encoder : str
The class to use for encoding/decoding data from websockets. The type of encoding to use for encoding/decoding data from websockets,
should be either 'json' or 'etf'.
""" """
token = "" token = ""
@ -38,7 +40,7 @@ class ClientConfig(LoggingClass):
manhole_enable = True manhole_enable = True
manhole_bind = ('127.0.0.1', 8484) manhole_bind = ('127.0.0.1', 8484)
encoding_cls = None encoder = 'json'
class Client(object): class Client(object):
@ -82,7 +84,7 @@ class Client(object):
self.state = State(self) self.state = State(self)
self.api = APIClient(self) self.api = APIClient(self)
self.gw = GatewayClient(self, self.config.encoding_cls) self.gw = GatewayClient(self, self.config.encoder)
if self.config.manhole_enable: if self.config.manhole_enable:
self.manhole_locals = { self.manhole_locals = {

21
disco/gateway/client.py

@ -3,9 +3,9 @@ import zlib
import six import six
import ssl import ssl
from disco.gateway.packets import OPCode from disco.gateway.packets import OPCode, RECV, SEND
from disco.gateway.events import GatewayEvent from disco.gateway.events import GatewayEvent
from disco.gateway.encoding.json import JSONEncoder from disco.gateway.encoding import ENCODERS
from disco.util.websocket import Websocket from disco.util.websocket import Websocket
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
@ -16,20 +16,20 @@ class GatewayClient(LoggingClass):
GATEWAY_VERSION = 6 GATEWAY_VERSION = 6
MAX_RECONNECTS = 5 MAX_RECONNECTS = 5
def __init__(self, client, encoder=None): def __init__(self, client, encoder='json'):
super(GatewayClient, self).__init__() super(GatewayClient, self).__init__()
self.client = client self.client = client
self.encoder = encoder or JSONEncoder self.encoder = ENCODERS[encoder]
self.events = client.events self.events = client.events
self.packets = client.packets self.packets = client.packets
# Create emitter and bind to gateway payloads # Create emitter and bind to gateway payloads
self.packets.on(OPCode.DISPATCH, self.handle_dispatch) self.packets.on((RECV, OPCode.DISPATCH), self.handle_dispatch)
self.packets.on(OPCode.HEARTBEAT, self.handle_heartbeat) self.packets.on((RECV, OPCode.HEARTBEAT), self.handle_heartbeat)
self.packets.on(OPCode.RECONNECT, self.handle_reconnect) self.packets.on((RECV, OPCode.RECONNECT), self.handle_reconnect)
self.packets.on(OPCode.INVALID_SESSION, self.handle_invalid_session) self.packets.on((RECV, OPCode.INVALID_SESSION), self.handle_invalid_session)
self.packets.on(OPCode.HELLO, self.handle_hello) self.packets.on((RECV, OPCode.HELLO), self.handle_hello)
# Bind to ready payload # Bind to ready payload
self.events.on('Ready', self.on_ready) self.events.on('Ready', self.on_ready)
@ -50,6 +50,7 @@ class GatewayClient(LoggingClass):
self._heartbeat_task = None self._heartbeat_task = None
def send(self, op, data): def send(self, op, data):
self.packets.emit((SEND, op), data)
self.ws.send(self.encoder.encode({ self.ws.send(self.encoder.encode({
'op': op.value, 'op': op.value,
'd': data, 'd': data,
@ -119,7 +120,7 @@ class GatewayClient(LoggingClass):
self.seq = data['s'] self.seq = data['s']
# Emit packet # Emit packet
self.packets.emit(OPCode[data['op']], data) self.packets.emit((RECV, OPCode[data['op']]), data)
def on_error(self, error): def on_error(self, error):
if isinstance(error, KeyboardInterrupt): if isinstance(error, KeyboardInterrupt):

19
disco/gateway/events.py

@ -1,3 +1,5 @@
from __future__ import print_function
import inflection import inflection
import six import six
@ -49,6 +51,23 @@ class GatewayEvent(Model):
raise AttributeError(name) raise AttributeError(name)
def debug(func=None):
def deco(cls):
old_init = cls.__init__
def new_init(self, obj, *args, **kwargs):
if func:
print(func(obj))
else:
print(obj)
old_init(self, obj, *args, **kwargs)
cls.__init__ = new_init
return cls
return deco
def wraps_model(model, alias=None): def wraps_model(model, alias=None):
alias = alias or model.__name__.lower() alias = alias or model.__name__.lower()

3
disco/gateway/packets.py

@ -1,5 +1,8 @@
from holster.enum import Enum from holster.enum import Enum
SEND = object()
RECV = object()
OPCode = Enum( OPCode = Enum(
DISPATCH=0, DISPATCH=0,
HEARTBEAT=1, HEARTBEAT=1,

32
disco/state.py

@ -79,7 +79,7 @@ class State(object):
EVENTS = [ EVENTS = [
'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove', 'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove',
'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete', 'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete',
'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate' 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceStateUpdate', 'MessageCreate',
] ]
def __init__(self, client, config=None): def __init__(self, client, config=None):
@ -96,7 +96,7 @@ class State(object):
# If message tracking is enabled, listen to those events # If message tracking is enabled, listen to those events
if self.config.track_messages: if self.config.track_messages:
self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size))
self.EVENTS += ['MessageCreate', 'MessageDelete'] self.EVENTS += ['MessageDelete']
# The bound listener objects # The bound listener objects
self.listeners = [] self.listeners = []
@ -120,25 +120,21 @@ class State(object):
func = 'on_' + inflection.underscore(event) func = 'on_' + inflection.underscore(event)
self.listeners.append(self.client.events.on(event, getattr(self, func))) self.listeners.append(self.client.events.on(event, getattr(self, func)))
def fill_messages(self, channel):
for message in reversed(next(channel.messages_iter(bulk=True))):
self.messages[channel.id].append(
StackMessage(message.id, message.channel_id, message.author.id))
def on_ready(self, event): def on_ready(self, event):
self.me = event.user self.me = event.user
def on_message_create(self, event): def on_message_create(self, event):
if self.config.track_messages:
self.messages[event.message.channel_id].append( self.messages[event.message.channel_id].append(
StackMessage(event.message.id, event.message.channel_id, event.message.author.id)) StackMessage(event.message.id, event.message.channel_id, event.message.author.id))
def on_message_update(self, event): if event.message.channel_id in self.channels:
message, cid = event.message, event.message.channel_id self.channels[event.message.channel_id].last_message_id = event.message.id
if cid not in self.messages:
return
sm = next((i for i in self.messages[cid] if i.id == message.id), None)
if not sm:
return
sm.id = message.id
sm.channel_id = cid
sm.author_id = message.author.id
def on_message_delete(self, event): def on_message_delete(self, event):
if event.channel_id not in self.messages: if event.channel_id not in self.messages:
@ -150,6 +146,14 @@ class State(object):
self.messages[event.channel_id].remove(sm) self.messages[event.channel_id].remove(sm)
def on_message_delete_bulk(self, event):
if event.channel_id not in self.messages:
return
for sm in self.messages[event.channel_id]:
if sm.id in event.ids:
self.messages[event.channel_id].remove(sm)
def on_guild_create(self, event): def on_guild_create(self, event):
self.guilds[event.guild.id] = event.guild self.guilds[event.guild.id] = event.guild
self.channels.update(event.guild.channels) self.channels.update(event.guild.channels)

16
disco/types/base.py

@ -124,8 +124,18 @@ def binary(obj):
return bytes(obj) return bytes(obj)
def field(typ, alias=None): def with_equality(field):
pass class T(object):
def __eq__(self, other):
return getattr(self, field) == getattr(other, field)
return T
def with_hash(field):
class T(object):
def __hash__(self, other):
return hash(getattr(self, field))
return T
class ModelMeta(type): class ModelMeta(type):
@ -156,7 +166,7 @@ class Model(six.with_metaclass(ModelMeta)):
obj = kwargs obj = kwargs
for name, field in self._fields.items(): for name, field in self._fields.items():
if name not in obj or not obj[field.src_name]: if field.src_name not in obj or not obj[field.src_name]:
if field.has_default(): if field.has_default():
setattr(self, field.dst_name, field.default()) setattr(self, field.dst_name, field.default())
continue continue

53
disco/types/channel.py

@ -4,7 +4,7 @@ from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text
from disco.types.permissions import PermissionValue from disco.types.permissions import PermissionValue
from disco.util import to_snowflake from disco.util import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property, one_or_many, chunks
from disco.types.user import User from disco.types.user import User
from disco.types.permissions import Permissions, Permissible from disco.types.permissions import Permissions, Permissible
from disco.voice.client import VoiceClient from disco.voice.client import VoiceClient
@ -81,7 +81,7 @@ class Channel(Model, Permissible):
guild_id = Field(snowflake) guild_id = Field(snowflake)
name = Field(text) name = Field(text)
topic = Field(text) topic = Field(text)
_last_message_id = Field(snowflake, alias='last_message_id') last_message_id = Field(snowflake)
position = Field(int) position = Field(int)
bitrate = Field(int) bitrate = Field(int)
recipients = Field(listof(User)) recipients = Field(listof(User))
@ -138,15 +138,6 @@ class Channel(Model, Permissible):
""" """
return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM) return self.type in (ChannelType.GUILD_VOICE, ChannelType.GROUP_DM)
@property
def last_message_id(self):
"""
Returns the ID of the last message sent in this channel
"""
if self.id not in self.client.state.messages:
return self._last_message_id
return self.client.state.messages[self.id][-1].id
@property @property
def messages(self): def messages(self):
""" """
@ -159,7 +150,7 @@ class Channel(Model, Permissible):
Creates a new :class:`MessageIterator` for the channel with the given Creates a new :class:`MessageIterator` for the channel with the given
keyword arguments keyword arguments
""" """
return MessageIterator(self.client, self.id, before=self.last_message_id, **kwargs) return MessageIterator(self.client, self.id, **kwargs)
@cached_property @cached_property
def guild(self): def guild(self):
@ -242,9 +233,40 @@ class Channel(Model, Permissible):
def delete_overwrite(self, ow): def delete_overwrite(self, ow):
self.client.api.channels_permissions_delete(self.id, ow.id) self.client.api.channels_permissions_delete(self.id, ow.id)
def delete_messages_bulk(self, messages): def delete_message(self, message):
"""
Deletes a single message from this channel.
Args
----
message : snowflake|:class:`disco.types.message.Message`
The message to delete.
"""
self.client.api.channels_messages_delete(self.id, to_snowflake(message))
@one_or_many
def delete_messages(self, messages):
"""
Deletes a set of messages using the correct API route based on the number
of messages passed.
Args
----
messages : list[snowflake|:class:`disco.types.message.Message`]
List of messages (or message ids) to delete. All messages must originate
from this channel.
"""
messages = map(to_snowflake, messages) messages = map(to_snowflake, messages)
self.client.api.channels_messages_delete_bulk(self.id, messages)
if not messages:
return
if len(messages) <= 2:
for msg in messages:
self.delete_message(msg)
else:
for chunk in chunks(messages, 100):
self.client.api.channels_messages_delete_bulk(self.id, chunk)
class MessageIterator(object): class MessageIterator(object):
@ -283,9 +305,6 @@ class MessageIterator(object):
self.last = None self.last = None
self._buffer = [] self._buffer = []
if not before and not after:
raise Exception('Must specify at most one of before or after')
if not any((before, after)) and self.direction == self.Direction.DOWN: if not any((before, after)) and self.direction == self.Direction.DOWN:
raise Exception('Must specify either before or after for downward seeking') raise Exception('Must specify either before or after for downward seeking')

4
disco/types/user.py

@ -1,7 +1,7 @@
from disco.types.base import Model, Field, snowflake, text, binary from disco.types.base import Model, Field, snowflake, text, binary, with_equality, with_hash
class User(Model): class User(Model, with_equality('id'), with_hash('id')):
id = Field(snowflake) id = Field(snowflake)
username = Field(text) username = Field(text)
discriminator = Field(str) discriminator = Field(str)

9
disco/util/config.py

@ -21,8 +21,8 @@ class Config(object):
data = f.read() data = f.read()
_, ext = os.path.splitext(path) _, ext = os.path.splitext(path)
Serializer.check_format(ext) Serializer.check_format(ext[1:])
inst.__dict__.update(Serializer.load(ext, data)) inst.__dict__.update(Serializer.loads(ext[1:], data))
return inst return inst
def from_prefix(self, prefix): def from_prefix(self, prefix):
@ -33,10 +33,13 @@ class Config(object):
if k.startswith(prefix): if k.startswith(prefix):
obj[k[len(prefix):]] = v obj[k[len(prefix):]] = v
return obj return Config(obj)
def update(self, other): def update(self, other):
if isinstance(other, Config): if isinstance(other, Config):
other = other.__dict__ other = other.__dict__
self.__dict__.update(other) self.__dict__.update(other)
def to_dict(self):
return self.__dict__

49
disco/util/functional.py

@ -1,5 +1,54 @@
from gevent.lock import RLock from gevent.lock import RLock
from six.moves import range
NO_MORE_SENTINEL = object()
def take(seq, count):
"""
Take count many elements from a sequence or generator.
Args
----
seq : sequnce or generator
The sequnce to take elements from.
count : int
The number of elments to take.
"""
for _ in range(count):
i = next(seq, NO_MORE_SENTINEL)
if i is NO_MORE_SENTINEL:
raise StopIteration
yield i
def chunks(obj, size):
"""
Splits a list into sized chunks.
Args
----
obj : list
List to split up.
size : int
Size of chunks to split list into.
"""
for i in range(0, len(obj), size):
yield obj[i:i + size]
def one_or_many(f):
"""
Wraps a function so that it will either take a single argument, or a variable
number of args.
"""
def _f(*args):
if len(args) == 1:
return f(args[0])
return f(*args)
return _f
def cached_property(f): def cached_property(f):
""" """

18
disco/util/snowflake.py

@ -0,0 +1,18 @@
from datetime import datetime
DISCORD_EPOCH = 1420070400000
def to_datetime(snowflake):
"""
Converts a snowflake to a UTC datetime.
"""
return datetime.utcfromtimestamp(to_unix(snowflake))
def to_unix(snowflake):
return to_unix_ms(snowflake) / 1000
def to_unix_ms(snowflake):
return ((int(snowflake) >> 22) + DISCORD_EPOCH)

2
requirements.txt

@ -1,5 +1,5 @@
gevent==1.1.2 gevent==1.1.2
holster==1.0.5 holster==1.0.6
inflection==0.3.1 inflection==0.3.1
requests==2.11.1 requests==2.11.1
six==1.10.0 six==1.10.0

Loading…
Cancel
Save