Browse Source

More hashmaps, cleanup and fixes

pull/11/head
Andrei 9 years ago
parent
commit
492b26326a
  1. 6
      disco/api/client.py
  2. 2
      disco/api/http.py
  3. 3
      disco/bot/bot.py
  4. 2
      disco/bot/command.py
  5. 4
      disco/bot/plugin.py
  6. 7
      disco/types/base.py
  7. 31
      disco/types/message.py
  8. 2
      disco/types/user.py

6
disco/api/client.py

@ -193,7 +193,7 @@ class APIClient(LoggingClass):
def guilds_channels_list(self, guild): def guilds_channels_list(self, guild):
r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_CHANNELS_LIST, dict(guild=guild))
return Channel.create_map(self.client, r.json(), guild_id=guild) return Channel.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_channels_create(self, guild, **kwargs): def guilds_channels_create(self, guild, **kwargs):
r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs) r = self.http(Routes.GUILDS_CHANNELS_CREATE, dict(guild=guild), json=kwargs)
@ -207,7 +207,7 @@ class APIClient(LoggingClass):
def guilds_members_list(self, guild): def guilds_members_list(self, guild):
r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_MEMBERS_LIST, dict(guild=guild))
return GuildMember.create_map(self.client, r.json(), guild_id=guild) return GuildMember.create_hash(self.client, 'id', r.json(), guild_id=guild)
def guilds_members_get(self, guild, member): def guilds_members_get(self, guild, member):
r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member)) r = self.http(Routes.GUILDS_MEMBERS_GET, dict(guild=guild, member=member))
@ -224,7 +224,7 @@ class APIClient(LoggingClass):
def guilds_bans_list(self, guild): def guilds_bans_list(self, guild):
r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild)) r = self.http(Routes.GUILDS_BANS_LIST, dict(guild=guild))
return User.create_map(self.client, r.json()) return User.create_hash(self.client, 'id', r.json())
def guilds_bans_create(self, guild, user, delete_message_days): def guilds_bans_create(self, guild, user, delete_message_days):
self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={ self.http(Routes.GUILDS_BANS_CREATE, dict(guild=guild, user=user), params={

2
disco/api/http.py

@ -211,7 +211,7 @@ class HTTPClient(LoggingClass):
# Make the actual request # Make the actual request
url = self.BASE_URL + route[1].format(**args) url = self.BASE_URL + route[1].format(**args)
self.log.info('%s %s', route[0].value, url) self.log.info('%s %s (%s)', route[0].value, url, kwargs.get('params'))
r = requests.request(route[0].value, url, **kwargs) r = requests.request(route[0].value, url, **kwargs)
# Update rate limiter # Update rate limiter

3
disco/bot/bot.py

@ -214,7 +214,8 @@ class Bot(object):
""" """
Computes a single regex which matches all possible command combinations. Computes a single regex which matches all possible command combinations.
""" """
re_str = '|'.join(command.regex for command in self.commands) commands = list(self.commands)
re_str = '|'.join(command.regex for command in commands)
if re_str: if re_str:
self.command_matches_re = re.compile(re_str) self.command_matches_re = re.compile(re_str)
else: else:

2
disco/bot/command.py

@ -5,7 +5,7 @@ from holster.enum import Enum
from disco.bot.parser import ArgumentSet, ArgumentError from disco.bot.parser import ArgumentSet, ArgumentError
from disco.util.functional import cached_property from disco.util.functional import cached_property
REGEX_FMT = '({})' REGEX_FMT = '{}'
ARGS_REGEX = '( ((?:\n|.)*)$|$)' ARGS_REGEX = '( ((?:\n|.)*)$|$)'
USER_MENTION_RE = re.compile('<@!?([0-9]+)>') USER_MENTION_RE = re.compile('<@!?([0-9]+)>')

4
disco/bot/plugin.py

@ -258,7 +258,7 @@ class Plugin(LoggingClass, PluginDeco):
self.ctx['user'] = event.author self.ctx['user'] = event.author
for pre in self._pre[typ]: for pre in self._pre[typ]:
event = pre(event, args, kwargs) event = pre(func, event, args, kwargs)
if event is None: if event is None:
return False return False
@ -266,7 +266,7 @@ class Plugin(LoggingClass, PluginDeco):
result = func(event, *args, **kwargs) result = func(event, *args, **kwargs)
for post in self._post[typ]: for post in self._post[typ]:
post(event, args, kwargs, result) post(func, event, args, kwargs, result)
return True return True

7
disco/types/base.py

@ -74,7 +74,6 @@ class Field(object):
try: try:
return self.deserializer(raw, client) return self.deserializer(raw, client)
except Exception as e: except Exception as e:
raise
six.reraise(ConversionError, ConversionError(self, raw, e)) six.reraise(ConversionError, ConversionError(self, raw, e))
@staticmethod @staticmethod
@ -337,6 +336,12 @@ class Model(six.with_metaclass(ModelMeta, AsyncChainable)):
def create_map(cls, client, data, **kwargs): def create_map(cls, client, data, **kwargs):
return list(map(functools.partial(cls.create, client, **kwargs), data)) return list(map(functools.partial(cls.create, client, **kwargs), data))
@classmethod
def create_hash(cls, client, key, data, **kwargs):
return HashMap({
getattr(item, key): cls.create(client, item, **kwargs) for item in data
})
@classmethod @classmethod
def attach(cls, it, data): def attach(cls, it, data):
for item in it: for item in it:

31
disco/types/message.py

@ -1,5 +1,7 @@
import re import re
import six
import functools import functools
import unicodedata
from holster.enum import Enum from holster.enum import Enum
@ -105,7 +107,7 @@ class MessageEmbed(SlottedModel):
title = Field(text) title = Field(text)
type = Field(str, default='rich') type = Field(str, default='rich')
description = Field(text) description = Field(text)
url = Field(str) url = Field(text)
timestamp = Field(lazy_datetime) timestamp = Field(lazy_datetime)
color = Field(int) color = Field(int)
footer = Field(MessageEmbedFooter) footer = Field(MessageEmbedFooter)
@ -139,8 +141,8 @@ class MessageAttachment(SlottedModel):
""" """
id = Field(str) id = Field(str)
filename = Field(text) filename = Field(text)
url = Field(str) url = Field(text)
proxy_url = Field(str) proxy_url = Field(text)
size = Field(int) size = Field(int)
height = Field(int) height = Field(int)
width = Field(int) width = Field(int)
@ -327,13 +329,13 @@ class Message(SlottedModel):
@cached_property @cached_property
def with_proper_mentions(self): def with_proper_mentions(self):
def replace_user(u): def replace_user(u):
return '@' + str(u) return u'@' + six.text_type(u)
def replace_role(r): def replace_role(r):
return '@' + str(r) return u'@' + six.text_type(r)
def replace_channel(c): def replace_channel(c):
return str(c) return six.text_type(c)
return self.replace_mentions(replace_user, replace_role, replace_channel) return self.replace_mentions(replace_user, replace_role, replace_channel)
@ -382,25 +384,28 @@ class Message(SlottedModel):
class MessageTable(object): class MessageTable(object):
def __init__(self, sep=' | ', codeblock=True, header_break=True): def __init__(self, sep=' | ', codeblock=True, header_break=True, language=None):
self.header = [] self.header = []
self.entries = [] self.entries = []
self.size_index = {} self.size_index = {}
self.sep = sep self.sep = sep
self.codeblock = codeblock self.codeblock = codeblock
self.header_break = header_break self.header_break = header_break
self.language = language
def recalculate_size_index(self, cols): def recalculate_size_index(self, cols):
for idx, col in enumerate(cols): for idx, col in enumerate(cols):
if idx not in self.size_index or len(col) > self.size_index[idx]: size = len(unicodedata.normalize('NFC', col))
self.size_index[idx] = len(col) if idx not in self.size_index or size > self.size_index[idx]:
self.size_index[idx] = size
def set_header(self, *args): def set_header(self, *args):
args = list(map(six.text_type, args))
self.header = args self.header = args
self.recalculate_size_index(args) self.recalculate_size_index(args)
def add(self, *args): def add(self, *args):
args = list(map(lambda v: v if isinstance(v, basestring) else str(v), args)) args = list(map(six.text_type, args))
self.entries.append(args) self.entries.append(args)
self.recalculate_size_index(args) self.recalculate_size_index(args)
@ -414,15 +419,17 @@ class MessageTable(object):
return data.rstrip() return data.rstrip()
def compile(self): def compile(self):
data = []
if self.header:
data = [self.compile_one(self.header)] data = [self.compile_one(self.header)]
if self.header_break: if self.header and self.header_break:
data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1)) data.append('-' * (sum(self.size_index.values()) + (len(self.header) * len(self.sep)) + 1))
for row in self.entries: for row in self.entries:
data.append(self.compile_one(row)) data.append(self.compile_one(row))
if self.codeblock: if self.codeblock:
return '```' + '\n'.join(data) + '```' return '```{}'.format(self.language if self.language else '') + '\n'.join(data) + '```'
return '\n'.join(data) return '\n'.join(data)

2
disco/types/user.py

@ -28,7 +28,7 @@ class User(SlottedModel, with_equality('id'), with_hash('id')):
return '<@{}>'.format(self.id) return '<@{}>'.format(self.id)
def __str__(self): def __str__(self):
return u'{}#{}'.format(self.username, self.discriminator) return u'{}#{}'.format(self.username, str(self.discriminator).zfill(4))
def __repr__(self): def __repr__(self):
return u'<User {} ({})>'.format(self.id, self) return u'<User {} ({})>'.format(self.id, self)

Loading…
Cancel
Save