Browse Source

Code cleanliness pass

pull/9/head
andrei 9 years ago
parent
commit
2a311cc336
  1. 4
      disco/api/client.py
  2. 9
      disco/api/http.py
  3. 5
      disco/bot/bot.py
  4. 21
      disco/bot/command.py
  5. 23
      disco/bot/plugin.py
  6. 1
      disco/bot/providers/disk.py
  7. 5
      disco/bot/providers/redis.py
  8. 6
      disco/bot/providers/rocksdb.py
  9. 4
      disco/client.py
  10. 6
      disco/gateway/client.py
  11. 2
      disco/gateway/encoding/json.py
  12. 4
      disco/gateway/events.py
  13. 26
      disco/gateway/sharder.py
  14. 10
      disco/types/base.py
  15. 12
      disco/types/channel.py
  16. 2
      disco/types/guild.py
  17. 17
      disco/types/message.py
  18. 3
      disco/types/permissions.py
  19. 6
      disco/types/webhook.py
  20. 2
      disco/util/config.py
  21. 2
      disco/util/snowflake.py
  22. 14
      disco/voice/client.py

4
disco/api/client.py

@ -88,8 +88,8 @@ class APIClient(LoggingClass):
def channels_messages_modify(self, channel, message, content):
r = self.http(Routes.CHANNELS_MESSAGES_MODIFY,
dict(channel=channel, message=message),
json={'content': content})
dict(channel=channel, message=message),
json={'content': content})
return Message.create(self.client, r.json())
def channels_messages_delete(self, channel, message):

9
disco/api/http.py

@ -48,7 +48,8 @@ class Routes(object):
CHANNELS_MESSAGES_REACTIONS_GET = (HTTPMethod.GET, CHANNELS + '/messages/{message}/reactions/{emoji}')
CHANNELS_MESSAGES_REACTIONS_CREATE = (HTTPMethod.PUT, CHANNELS + '/messages/{message}/reactions/{emoji}/@me')
CHANNELS_MESSAGES_REACTIONS_DELETE_ME = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/@me')
CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE, CHANNELS + '/messages/{message}/reactions/{emoji}/{user}')
CHANNELS_MESSAGES_REACTIONS_DELETE_USER = (HTTPMethod.DELETE,
CHANNELS + '/messages/{message}/reactions/{emoji}/{user}')
CHANNELS_PERMISSIONS_MODIFY = (HTTPMethod.PUT, CHANNELS + '/permissions/{permission}')
CHANNELS_PERMISSIONS_DELETE = (HTTPMethod.DELETE, CHANNELS + '/permissions/{permission}')
CHANNELS_INVITES_LIST = (HTTPMethod.GET, CHANNELS + '/invites')
@ -222,13 +223,15 @@ class HTTPClient(LoggingClass):
raise APIException('Request failed', r.status_code, r.content)
else:
if r.status_code == 429:
self.log.warning('Request responded w/ 429, retrying (but this should not happen, check your clock sync')
self.log.warning(
'Request responded w/ 429, retrying (but this should not happen, check your clock sync')
# If we hit the max retries, throw an error
retry += 1
if retry > self.MAX_RETRIES:
self.log.error('Failing request, hit max retries')
raise APIException('Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content)
raise APIException(
'Request failed after {} attempts'.format(self.MAX_RETRIES), r.status_code, r.content)
backoff = self.random_backoff()
self.log.warning('Request to `{}` failed with code {}, retrying after {}s ({})'.format(

5
disco/bot/bot.py

@ -245,7 +245,7 @@ class Bot(object):
mention_roles = []
if msg.guild:
mention_roles = list(filter(lambda r: msg.is_mentioned(r),
msg.guild.get_member(self.client.state.me).roles))
msg.guild.get_member(self.client.state.me).roles))
if not any((
self.config.commands_mention_rules['user'] and mention_direct,
@ -370,6 +370,9 @@ class Bot(object):
Plugin class to initialize and load.
config : Optional
The configuration to load the plugin with.
ctx : Optional[dict]
Context (previous state) to pass the plugin. Usually used along w/
unload.
"""
if cls.__name__ in self.plugins:
raise Exception('Cannot add already added plugin: {}'.format(cls.__name__))

21
disco/bot/command.py

@ -107,16 +107,23 @@ class Command(object):
self.plugin = plugin
self.func = func
self.triggers = [trigger]
self.args = None
self.level = None
self.group = None
self.is_regex = None
self.oob = False
self.update(*args, **kwargs)
def update(self, args=None, level=None, aliases=None, group=None, is_regex=None, oob=False):
self.triggers += aliases or []
def resolve_role(ctx, id):
return ctx.msg.guild.roles.get(id)
def resolve_role(ctx, rid):
return ctx.msg.guild.roles.get(rid)
def resolve_user(ctx, id):
return ctx.msg.mentions.get(id)
def resolve_user(ctx, uid):
return ctx.msg.mentions.get(uid)
self.args = ArgumentSet.from_string(args or '', {
'mention': self.mention_type([resolve_role, resolve_user]),
@ -136,17 +143,17 @@ class Command(object):
if not res:
raise TypeError('Invalid mention: {}'.format(i))
id = int(res.group(1))
mid = int(res.group(1))
for getter in getters:
obj = getter(ctx, id)
obj = getter(ctx, mid)
if obj:
return obj
if force:
raise TypeError('Cannot resolve mention: {}'.format(id))
return id
return mid
return _f
@cached_property

23
disco/bot/plugin.py

@ -154,6 +154,14 @@ class Plugin(LoggingClass, PluginDeco):
self.storage = bot.storage
self.config = config
# General declartions
self.listeners = []
self.commands = {}
self.schedules = {}
self.greenlets = weakref.WeakSet()
self._pre = {}
self._post = {}
# This is an array of all meta functions we sniff at init
self.meta_funcs = []
@ -248,7 +256,7 @@ class Plugin(LoggingClass, PluginDeco):
return True
def register_listener(self, func, what, desc, priority=Priority.NONE, conditional=None):
def register_listener(self, func, what, desc, **kwargs):
"""
Registers a listener
@ -260,15 +268,13 @@ class Plugin(LoggingClass, PluginDeco):
The function to be registered.
desc
The descriptor of the event/packet.
priority : Priority
The priority of this listener.
"""
func = functools.partial(self._dispatch, 'listener', func)
if what == 'event':
li = self.bot.client.events.on(desc, func, priority=priority, conditional=conditional)
li = self.bot.client.events.on(desc, func, **kwargs)
elif what == 'packet':
li = self.bot.client.packets.on(desc, func, priority=priority, conditional=conditional)
li = self.bot.client.packets.on(desc, func, **kwargs)
else:
raise Exception('Invalid listener what: {}'.format(what))
@ -305,8 +311,13 @@ class Plugin(LoggingClass, PluginDeco):
The function to be registered.
interval : int
Interval (in seconds) to repeat the function on.
repeat : bool
Whether this schedule is repeating (or one time).
init : bool
Whether to run this schedule once immediatly, or wait for the first
scheduled iteration.
"""
def repeat():
def func():
if init:
func()

1
disco/bot/providers/disk.py

@ -13,6 +13,7 @@ class DiskProvider(BaseProvider):
self.fsync = config.get('fsync', False)
self.fsync_changes = config.get('fsync_changes', 1)
self.autosave_task = None
self.change_count = 0
def autosave_loop(self, interval):

5
disco/bot/providers/redis.py

@ -10,8 +10,9 @@ from .base import BaseProvider, SEP_SENTINEL
class RedisProvider(BaseProvider):
def __init__(self, config):
self.config = config
super(RedisProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.conn = None
def load(self):
self.conn = redis.Redis(
@ -39,5 +40,5 @@ class RedisProvider(BaseProvider):
def set(self, key, value):
self.conn.set(key, Serializer.dumps(self.format, value))
def delete(self, key, value):
def delete(self, key):
self.conn.delete(key)

6
disco/bot/providers/rocksdb.py

@ -12,11 +12,13 @@ from .base import BaseProvider, SEP_SENTINEL
class RocksDBProvider(BaseProvider):
def __init__(self, config):
self.config = config
super(RocksDBProvider, self).__init__(config)
self.format = config.get('format', 'pickle')
self.path = config.get('path', 'storage.db')
self.db = None
def k(self, k):
@staticmethod
def k(k):
return bytes(k) if six.PY3 else str(k.encode('utf-8'))
def load(self):

4
disco/client.py

@ -98,8 +98,8 @@ class Client(LoggingClass):
}
self.manhole = DiscoBackdoorServer(self.config.manhole_bind,
banner='Disco Manhole',
localf=lambda: self.manhole_locals)
banner='Disco Manhole',
localf=lambda: self.manhole_locals)
self.manhole.start()
def update_presence(self, game=None, status=None, afk=False, since=0.0):

6
disco/gateway/client.py

@ -81,15 +81,15 @@ class GatewayClient(LoggingClass):
self.log.debug('Dispatching %s', obj.__class__.__name__)
self.client.events.emit(obj.__class__.__name__, obj)
def handle_heartbeat(self, packet):
def handle_heartbeat(self, _):
self._send(OPCode.HEARTBEAT, self.seq)
def handle_reconnect(self, packet):
def handle_reconnect(self, _):
self.log.warning('Received RECONNECT request, forcing a fresh reconnect')
self.session_id = None
self.ws.close()
def handle_invalid_session(self, packet):
def handle_invalid_session(self, _):
self.log.warning('Recieved INVALID_SESSION, forcing a fresh reconnect')
self.session_id = None
self.ws.close()

2
disco/gateway/encoding/json.py

@ -1,7 +1,5 @@
from __future__ import absolute_import, print_function
import six
try:
import ujson as json
except ImportError:

4
disco/gateway/events.py

@ -16,8 +16,8 @@ EVENTS_MAP = {}
class GatewayEventMeta(ModelMeta):
def __new__(cls, name, parents, dct):
obj = super(GatewayEventMeta, cls).__new__(cls, name, parents, dct)
def __new__(mcs, name, parents, dct):
obj = super(GatewayEventMeta, mcs).__new__(mcs, name, parents, dct)
if name != 'GatewayEvent':
EVENTS_MAP[inflection.underscore(name).upper()] = obj

26
disco/gateway/sharder.py

@ -5,6 +5,8 @@ import gevent
import logging
import marshal
from six.moves import range
from disco.client import Client
from disco.bot import Bot, BotConfig
from disco.api.client import APIClient
@ -14,13 +16,13 @@ from disco.util.snowflake import calculate_shard
from disco.util.serializer import dump_function, load_function
def run_shard(config, id, pipe):
def run_shard(config, shard_id, pipe):
setup_logging(
level=logging.INFO,
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id)
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(shard_id)
)
config.shard_id = id
config.shard_id = shard_id
client = Client(config)
bot = Bot(client, BotConfig(config.bot))
bot.sharder = GIPCProxy(bot, pipe)
@ -34,8 +36,8 @@ class ShardHelper(object):
self.bot = bot
def keys(self):
for id in xrange(self.count):
yield id
for sid in range(self.count):
yield sid
def on(self, id, func):
if id == self.bot.client.config.shard_id:
@ -49,8 +51,8 @@ class ShardHelper(object):
pool = gevent.pool.Pool(self.count)
return dict(zip(range(self.count), pool.imap(lambda i: self.on(i, func).wait(timeout=timeout), range(self.count))))
def for_id(self, id, func):
shard = calculate_shard(self.count, id)
def for_id(self, sid, func):
shard = calculate_shard(self.count, sid)
return self.on(shard, func)
@ -63,9 +65,9 @@ class AutoSharder(object):
if self.config.shard_count > 1:
self.config.shard_count = 10
def run_on(self, id, raw):
def run_on(self, sid, raw):
func = load_function(raw)
return self.shards[id].execute(func).wait(timeout=15)
return self.shards[sid].execute(func).wait(timeout=15)
def run(self):
for shard in range(self.config.shard_count):
@ -80,7 +82,7 @@ class AutoSharder(object):
format='{} [%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s'.format(id)
)
def start_shard(self, id):
def start_shard(self, sid):
cpipe, ppipe = gipc.pipe(duplex=True, encoder=marshal.dumps, decoder=marshal.loads)
gipc.start_process(run_shard, (self.config, id, cpipe))
self.shards[id] = GIPCProxy(self, ppipe)
gipc.start_process(run_shard, (self.config, sid, cpipe))
self.shards[sid] = GIPCProxy(self, ppipe)

10
disco/types/base.py

@ -179,7 +179,7 @@ def with_equality(field):
def with_hash(field):
class T(object):
def __hash__(self, other):
def __hash__(self):
return hash(getattr(self, field))
return T
@ -190,7 +190,7 @@ SlottedModel = None
class ModelMeta(type):
def __new__(cls, name, parents, dct):
def __new__(mcs, name, parents, dct):
fields = {}
for parent in parents:
@ -217,7 +217,7 @@ class ModelMeta(type):
dct = {k: v for k, v in six.iteritems(dct) if k not in fields}
dct['_fields'] = fields
return super(ModelMeta, cls).__new__(cls, name, parents, dct)
return super(ModelMeta, mcs).__new__(mcs, name, parents, dct)
class AsyncChainable(object):
@ -280,8 +280,8 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
return inst
@classmethod
def create_map(cls, client, data):
return list(map(functools.partial(cls.create, client), data))
def create_map(cls, client, data, **kwargs):
return list(map(functools.partial(cls.create, client, **kwargs), data))
@classmethod
def attach(cls, it, data):

12
disco/types/channel.py

@ -57,11 +57,11 @@ class PermissionOverwrite(ChannelSubType):
def create(cls, channel, entity, allow=0, deny=0):
from disco.types.guild import Role
type = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER
ptype = PermissionOverwriteType.ROLE if isinstance(entity, Role) else PermissionOverwriteType.MEMBER
return cls(
client=channel.client,
id=entity.id,
type=type,
type=ptype,
allow=allow,
deny=deny,
channel_id=channel.id
@ -69,10 +69,10 @@ class PermissionOverwrite(ChannelSubType):
def save(self):
self.client.api.channels_permissions_modify(self.channel_id,
self.id,
self.allow.value or 0,
self.deny.value or 0,
self.type.name)
self.id,
self.allow.value or 0,
self.deny.value or 0,
self.type.name)
return self
def delete(self):

2
disco/types/guild.py

@ -41,6 +41,7 @@ class GuildEmoji(Emoji):
Roles this emoji is attached to.
"""
id = Field(snowflake)
guild_id = Field(snowflake)
name = Field(text)
require_colons = Field(bool)
managed = Field(bool)
@ -73,6 +74,7 @@ class Role(SlottedModel):
The position of this role in the hierarchy.
"""
id = Field(snowflake)
guild_id = Field(snowflake)
name = Field(text)
hoist = Field(bool)
managed = Field(bool)

17
disco/types/message.py

@ -303,8 +303,8 @@ class Message(SlottedModel):
bool
Whether the give entity was mentioned.
"""
id = to_snowflake(entity)
return id in self.mentions or id in self.mention_roles
entity = to_snowflake(entity)
return entity in self.mentions or entity in self.mention_roles
@cached_property
def without_mentions(self):
@ -340,11 +340,11 @@ class Message(SlottedModel):
return
def replace(match):
id = match.group(0)
if id in self.mention_roles:
return role_replace(id)
oid = match.group(0)
if oid in self.mention_roles:
return role_replace(oid)
else:
return user_replace(self.mentions.get(id))
return user_replace(self.mentions.get(oid))
return re.sub('<@!?([0-9]+)>', replace, self.content)
@ -376,14 +376,13 @@ class MessageTable(object):
data = self.sep.lstrip()
for idx, col in enumerate(cols):
padding = ' ' * ((self.size_index[idx] - len(col)))
padding = ' ' * (self.size_index[idx] - len(col))
data += col + padding + self.sep
return data.rstrip()
def compile(self):
data = []
data.append(self.compile_one(self.header))
data = [self.compile_one(self.header)]
if self.header_break:
data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1))

3
disco/types/permissions.py

@ -107,6 +107,9 @@ class PermissionValue(object):
class Permissible(object):
__slots__ = []
def get_permissions(self):
raise NotImplementedError
def can(self, user, *args):
perms = self.get_permissions(user)
return perms.administrator or perms.can(*args)

6
disco/types/webhook.py

@ -32,12 +32,14 @@ class Webhook(SlottedModel):
else:
return self.client.api.webhooks_modify(self.id, name, avatar)
def execute(self, content=None, username=None, avatar_url=None, tts=False, file=None, embeds=[], wait=False):
def execute(self, content=None, username=None, avatar_url=None, tts=False, fobj=None, embeds=[], wait=False):
# TODO: support file stuff properly
return self.client.api.webhooks_token_execute(self.id, self.token, {
'content': content,
'username': username,
'avatar_url': avatar_url,
'tts': tts,
'file': file,
'file': fobj,
'embeds': [i.to_dict() for i in embeds],
}, wait)

2
disco/util/config.py

@ -29,7 +29,7 @@ class Config(object):
return inst
def from_prefix(self, prefix):
prefix = prefix + '_'
prefix += '_'
obj = {}
for k, v in six.iteritems(self.__dict__):

2
disco/util/snowflake.py

@ -17,7 +17,7 @@ def to_unix(snowflake):
def to_unix_ms(snowflake):
return ((int(snowflake) >> 22) + DISCORD_EPOCH)
return (int(snowflake) >> 22) + DISCORD_EPOCH
def to_snowflake(i):

14
disco/voice/client.py

@ -102,6 +102,7 @@ class VoiceClient(LoggingClass):
self.endpoint = None
self.ssrc = None
self.port = None
self.udp = None
self.update_listener = None
@ -149,7 +150,7 @@ class VoiceClient(LoggingClass):
}
})
def on_voice_sdp(self, data):
def on_voice_sdp(self, _):
# Toggle speaking state so clients learn of our SSRC
self.set_speaking(True)
self.set_speaking(False)
@ -178,19 +179,18 @@ class VoiceClient(LoggingClass):
)
self.ws.run_forever()
def on_message(self, ws, msg):
def on_message(self, _, msg):
try:
data = self.encoder.decode(msg)
self.packets.emit(VoiceOPCode[data['op']], data['d'])
except:
self.log.exception('Failed to parse voice gateway message: ')
self.packets.emit(VoiceOPCode[data['op']], data['d'])
def on_error(self, ws, err):
def on_error(self, _, err):
# TODO
self.log.warning('Voice websocket error: {}'.format(err))
def on_open(self, ws):
def on_open(self, _):
self.send(VoiceOPCode.IDENTIFY, {
'server_id': self.channel.guild_id,
'user_id': self.client.state.me.id,
@ -198,7 +198,7 @@ class VoiceClient(LoggingClass):
'token': self.token
})
def on_close(self, ws, code, error):
def on_close(self, _, code, error):
# TODO
self.log.warning('Voice websocket disconnected (%s, %s)', code, error)

Loading…
Cancel
Save