diff --git a/disco/state.py b/disco/state.py index a08c983..2130787 100644 --- a/disco/state.py +++ b/disco/state.py @@ -1,11 +1,12 @@ import six import inflection -from collections import defaultdict, deque, namedtuple +from collections import deque, namedtuple from weakref import WeakValueDictionary from gevent.event import Event from disco.util.config import Config +from disco.util.hashmap import HashMap, DefaultHashMap class StackMessage(namedtuple('StackMessage', ['id', 'channel_id', 'author_id'])): @@ -99,15 +100,15 @@ class State(object): self.guilds_waiting_sync = 0 self.me = None - self.dms = {} - self.guilds = {} - self.channels = WeakValueDictionary() - self.users = WeakValueDictionary() - self.voice_states = WeakValueDictionary() + self.dms = HashMap() + self.guilds = HashMap() + self.channels = HashMap(WeakValueDictionary()) + self.users = HashMap(WeakValueDictionary()) + self.voice_states = HashMap(WeakValueDictionary()) # If message tracking is enabled, listen to those events if self.config.track_messages: - self.messages = defaultdict(lambda: deque(maxlen=self.config.track_messages_size)) + self.messages = DefaultHashMap(lambda: deque(maxlen=self.config.track_messages_size)) self.EVENTS += ['MessageDelete', 'MessageDeleteBulk'] # The bound listener objects diff --git a/disco/types/base.py b/disco/types/base.py index 35600b4..4c98170 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -7,6 +7,7 @@ from holster.enum import BaseEnumMeta from datetime import datetime as real_datetime from disco.util.functional import CachedSlotProperty +from disco.util.hashmap import HashMap DATETIME_FORMATS = [ '%Y-%m-%dT%H:%M:%S.%f', @@ -70,7 +71,7 @@ class Field(FieldType): class _Dict(FieldType): - default = dict + default = HashMap def __init__(self, typ, key=None): super(_Dict, self).__init__(typ) @@ -79,9 +80,9 @@ class _Dict(FieldType): def try_convert(self, raw, client): if self.key: converted = [self.typ(i, client) for i in raw] - return {getattr(i, self.key): i for i in converted} + return HashMap({getattr(i, self.key): i for i in converted}) else: - return {k: self.typ(v, client) for k, v in six.iteritems(raw)} + return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)}) class _List(FieldType): diff --git a/disco/util/hashmap.py b/disco/util/hashmap.py new file mode 100644 index 0000000..7a3ace1 --- /dev/null +++ b/disco/util/hashmap.py @@ -0,0 +1,51 @@ +import six + +from six.moves import filter, map +from collections import defaultdict +from UserDict import IterableUserDict + + +class HashMap(IterableUserDict): + def items(self): + return six.iteritems(self.data) + + def keys(self): + return six.iterkeys(self.data) + + def values(self): + return six.itervalues(self.data) + + def find(self, predicate): + if not callable(predicate): + raise TypeError('predicate must be callable') + + for obj in self.values(): + if predicate(obj): + yield obj + + def find_one(self, predicate): + return next(self.find(predicate), None) + + def select(self, **kwargs): + for obj in self.values(): + for k, v in six.iteritems(kwargs): + if getattr(obj, k) != v: + continue + yield obj + + def select_one(self, **kwargs): + return next(self.select(**kwargs), None) + + def filter(self, predicate): + if not callable(predicate): + raise TypeError('predicate must be callable') + return filter(self.values(), predicate) + + def map(self, predicate): + if not callable(predicate): + raise TypeError('predicate must be callable') + return map(self.values(), predicate) + + +class DefaultHashMap(defaultdict, HashMap): + pass