diff --git a/disco/gateway/events.py b/disco/gateway/events.py index c882490..d8451ef 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -13,10 +13,10 @@ class GatewayEvent(Model): if not cls: raise Exception('Could not find cls for {}'.format(data['t'])) - return cls.create(data['d']) + return cls.create(data['d'], client) @classmethod - def create(cls, obj): + def create(cls, obj, client): # If this event is wrapping a model, pull its fields if hasattr(cls, '_wraps_model'): alias, model = cls._wraps_model @@ -27,7 +27,7 @@ class GatewayEvent(Model): obj[alias] = data - return cls(obj) + return cls(obj, client) def wraps_model(model, alias=None): diff --git a/disco/state.py b/disco/state.py index f32b4ba..90d2b22 100644 --- a/disco/state.py +++ b/disco/state.py @@ -154,6 +154,10 @@ class State(object): self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) + for channel in event.guild.channels.values(): + channel.guild_id = event.guild.id + channel.guild = event.guild + for member in event.guild.members.values(): self.users[member.user.id] = member.user diff --git a/disco/types/base.py b/disco/types/base.py index 39b2f28..0279beb 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -5,28 +5,43 @@ import functools from datetime import datetime as real_datetime +def _make(typ, data, client): + args, _, _, _ = inspect.getargspec(typ) + if 'client' in args: + return typ(data, client) + return typ(data) + + def snowflake(data): - return int(data) + return int(data) if data else None def enum(typ): def _f(data): - return typ.get(data) + return typ.get(data) if data else None return _f def listof(typ): - def _f(data): - return list(map(typ, data)) + def _f(data, client=None): + if not data: + return [] + return [_make(typ, obj, client) for obj in data] return _f def dictof(typ, key=None): - def _f(data): + def _f(data, client=None): + if not data: + return {} + if key: - return {getattr(v, key): v for v in map(typ, data)} + return { + getattr(v, key): v for v in ( + _make(typ, i, client) for i in data + )} else: - return {k: typ(v) for k, v in six.iteritems(data)} + return {k: _make(typ, v, client) for k, v in six.iteritems(data)} return _f @@ -34,16 +49,16 @@ def alias(typ, name): return ('alias', name, typ) -def datetime(typ): - return real_datetime.strptime(typ.rsplit('+', 1)[0], '%Y-%m-%dT%H:%M:%S.%f') +def datetime(data): + return real_datetime.strptime(data.rsplit('+', 1)[0], '%Y-%m-%dT%H:%M:%S.%f') if data else None def text(obj): - return six.text_type(obj) + return six.text_type(obj) if obj else six.text_type() def binary(obj): - return six.text_type(obj) + return six.text_type(obj) if obj else six.text_type() class ModelMeta(type): @@ -55,7 +70,13 @@ class ModelMeta(type): fields[v[1]] = (k, v[2]) continue - if callable(v) or inspect.isclass(v): + if inspect.isclass(v): + fields[k] = v + elif callable(v): + args, _, _, _ = inspect.getargspec(v) + if 'self' in args: + continue + fields[k] = v dct['_fields'] = fields @@ -64,6 +85,8 @@ class ModelMeta(type): class Model(six.with_metaclass(ModelMeta)): def __init__(self, obj, client=None): + self.client = client + for name, typ in self.__class__._fields.items(): dest_name = name @@ -71,10 +94,24 @@ class Model(six.with_metaclass(ModelMeta)): dest_name, typ = typ if name not in obj or not obj[name]: + if inspect.isclass(typ) and issubclass(typ, Model): + res = None + elif isinstance(typ, type): + res = typ() + else: + res = typ(None) + setattr(self, dest_name, res) continue try: - v = typ(obj[name]) + if client: + args, _, _, _ = inspect.getargspec(typ) + if 'client' in args: + v = typ(obj[name], client) + else: + v = typ(obj[name]) + else: + v = typ(obj[name]) except TypeError as e: print('Failed during parsing of field {} => {} (`{}`)'.format(name, typ, obj[name])) raise e @@ -92,7 +129,7 @@ class Model(six.with_metaclass(ModelMeta)): @classmethod def create(cls, client, data): - return cls(data) + return cls(data, client) @classmethod def create_map(cls, client, data):