From 4b9a7f61ff52affe33ce3fad27974c658da8e28b Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 7 Oct 2016 02:43:32 -0500 Subject: [PATCH 1/3] Remove skema in favor of simple custom data modeling stuff --- disco/gateway/events.py | 107 +++++++++++++++++-------------------- disco/types/base.py | 99 ++++++++++++++++++++++++++++------ disco/types/channel.py | 44 +++++++-------- disco/types/guild.py | 93 +++++++++++++++----------------- disco/types/invite.py | 28 ++++------ disco/types/message.py | 67 +++++++++++------------ disco/types/permissions.py | 7 --- disco/types/user.py | 20 +++---- disco/types/voice.py | 30 +++++------ disco/util/__init__.py | 36 ------------- disco/util/types.py | 36 ------------- requirements.txt | 1 - 12 files changed, 258 insertions(+), 310 deletions(-) delete mode 100644 disco/util/types.py diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 5d1314d..c882490 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -1,25 +1,19 @@ import inflection -import skema import six -from disco.util import skema_find_recursive_by_type from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState +from disco.types.base import Model, snowflake, alias, listof # TODO: clean this... use BaseType, etc -class GatewayEvent(skema.Model): +class GatewayEvent(Model): @staticmethod def from_dispatch(client, data): cls = globals().get(inflection.camelize(data['t'].lower())) if not cls: raise Exception('Could not find cls for {}'.format(data['t'])) - obj = cls.create(data['d']) - - for field, value in skema_find_recursive_by_type(obj, skema.ModelType): - value.client = client - - return obj + return cls.create(data['d']) @classmethod def create(cls, obj): @@ -28,31 +22,29 @@ class GatewayEvent(skema.Model): alias, model = cls._wraps_model data = { - k: obj.pop(k) for k in six.iterkeys(model._fields_by_stored_name) if k in obj + k: obj.pop(k) for k in six.iterkeys(model._fields) if k in obj } obj[alias] = data - self = cls(obj) - self.validate() - return self + return cls(obj) def wraps_model(model, alias=None): alias = alias or model.__name__.lower() def deco(cls): - cls.add_field(alias, skema.ModelType(model)) + cls._fields[alias] = model cls._wraps_model = (alias, model) return cls return deco class Ready(GatewayEvent): - version = skema.IntType(stored_name='v') - session_id = skema.StringType() - user = skema.ModelType(User) - guilds = skema.ListType(skema.ModelType(Guild)) + version = alias(int, 'v') + session_id = str + user = User + guilds = listof(Guild) class Resumed(GatewayEvent): @@ -61,17 +53,17 @@ class Resumed(GatewayEvent): @wraps_model(Guild) class GuildCreate(GatewayEvent): - unavailable = skema.BooleanType(default=None) + unavailable = bool @wraps_model(Guild) class GuildUpdate(GatewayEvent): - guild = skema.ModelType(Guild) + pass class GuildDelete(GatewayEvent): - id = skema.SnowflakeType() - unavailable = skema.BooleanType(default=None) + id = snowflake + unavailable = bool @wraps_model(Channel) @@ -90,8 +82,8 @@ class ChannelDelete(ChannelCreate): class ChannelPinsUpdate(GatewayEvent): - channel_id = skema.SnowflakeType() - last_pin_timestamp = skema.IntType() + channel_id = snowflake + last_pin_timestamp = int @wraps_model(User) @@ -112,8 +104,8 @@ class GuildIntegrationsUpdate(GatewayEvent): class GuildMembersChunk(GatewayEvent): - guild_id = skema.SnowflakeType() - members = skema.ListType(skema.ModelType(GuildMember)) + guild_id = snowflake + members = listof(GuildMember) @wraps_model(GuildMember, alias='member') @@ -122,29 +114,27 @@ class GuildMemberAdd(GatewayEvent): class GuildMemberRemove(GatewayEvent): - guild_id = skema.SnowflakeType() - user = skema.ModelType(User) + guild_id = snowflake + user = User class GuildMemberUpdate(GatewayEvent): - guild_id = skema.SnowflakeType() - user = skema.ModelType(User) - roles = skema.ListType(skema.SnowflakeType()) + guild_id = snowflake + user = User + roles = listof(snowflake) class GuildRoleCreate(GatewayEvent): - guild_id = skema.SnowflakeType() - role = skema.ModelType(Role) + guild_id = snowflake + role = Role -class GuildRoleUpdate(GatewayEvent): - guild_id = skema.SnowflakeType() - role = skema.ModelType(Role) +class GuildRoleUpdate(GuildRoleCreate): + pass -class GuildRoleDelete(GatewayEvent): - guild_id = skema.SnowflakeType() - role = skema.ModelType(Role) +class GuildRoleDelete(GuildRoleCreate): + pass @wraps_model(Message) @@ -159,32 +149,33 @@ class MessageUpdate(MessageCreate): class MessageDelete(GatewayEvent): - id = skema.SnowflakeType() - channel_id = skema.SnowflakeType() + id = snowflake + channel_id = snowflake class MessageDeleteBulk(GatewayEvent): - channel_id = skema.SnowflakeType() - ids = skema.ListType(skema.SnowflakeType()) + channel_id = snowflake + ids = listof(snowflake) class PresenceUpdate(GatewayEvent): - class Game(skema.Model): - type = skema.IntType() - name = skema.StringType() - url = skema.StringType(required=False) + class Game(Model): + # TODO enum + type = int + name = str + url = str - user = skema.ModelType(User) - guild_id = skema.SnowflakeType() - roles = skema.ListType(skema.SnowflakeType()) - game = skema.ModelType(Game) - status = skema.StringType() + user = User + guild_id = snowflake + roles = listof(snowflake) + game = Game + status = str class TypingStart(GatewayEvent): - channel_id = skema.SnowflakeType() - user_id = skema.SnowflakeType() - timestamp = skema.IntType() + channel_id = snowflake + user_id = snowflake + timestamp = snowflake @wraps_model(VoiceState, alias='state') @@ -193,6 +184,6 @@ class VoiceStateUpdate(GatewayEvent): class VoiceServerUpdate(GatewayEvent): - token = skema.StringType() - endpoint = skema.StringType() - guild_id = skema.SnowflakeType() + token = str + endpoint = str + guild_id = snowflake diff --git a/disco/types/base.py b/disco/types/base.py index 2f9c41c..39b2f28 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -1,29 +1,98 @@ -import skema +import six +import inspect import functools -from disco.util import skema_find_recursive_by_type -# from disco.util.types import DeferredModel +from datetime import datetime as real_datetime -class BaseType(skema.Model): +def snowflake(data): + return int(data) + + +def enum(typ): + def _f(data): + return typ.get(data) + return _f + + +def listof(typ): + def _f(data): + return list(map(typ, data)) + return _f + + +def dictof(typ, key=None): + def _f(data): + if key: + return {getattr(v, key): v for v in map(typ, data)} + else: + return {k: typ(v) for k, v in six.iteritems(data)} + return _f + + +def alias(typ, name): + return ('alias', name, typ) + + +def datetime(typ): + return real_datetime.strptime(typ.rsplit('+', 1)[0], '%Y-%m-%dT%H:%M:%S.%f') + + +def text(obj): + return six.text_type(obj) + + +def binary(obj): + return six.text_type(obj) + + +class ModelMeta(type): + def __new__(cls, name, parents, dct): + fields = {} + for k, v in six.iteritems(dct): + if isinstance(v, tuple): + if v[0] == 'alias': + fields[v[1]] = (k, v[2]) + continue + + if callable(v) or inspect.isclass(v): + fields[k] = v + + dct['_fields'] = fields + return super(ModelMeta, cls).__new__(cls, name, parents, dct) + + +class Model(six.with_metaclass(ModelMeta)): + def __init__(self, obj, client=None): + for name, typ in self.__class__._fields.items(): + dest_name = name + + if isinstance(typ, tuple): + dest_name, typ = typ + + if name not in obj or not obj[name]: + continue + + try: + v = typ(obj[name]) + except TypeError as e: + print('Failed during parsing of field {} => {} (`{}`)'.format(name, typ, obj[name])) + raise e + + if client and isinstance(v, Model): + v.client = client + + setattr(self, dest_name, v) + def update(self, other): - for name, field in other.__class__._fields.items(): + for name in six.iterkeys(self.__class__.fields): value = getattr(other, name) if value: setattr(self, name, value) @classmethod def create(cls, client, data): - obj = cls(data) - - # Valdiate - obj.validate() - - for field, value in skema_find_recursive_by_type(obj, skema.ModelType): - value.client = client - - obj.client = client - return obj + return cls(data) @classmethod def create_map(cls, client, data): diff --git a/disco/types/channel.py b/disco/types/channel.py index 6ef1eb1..030dbf9 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,12 +1,11 @@ -import skema - from holster.enum import Enum +from disco.types.base import Model, snowflake, enum, listof, dictof, alias, text +from disco.types.permissions import PermissionValue + from disco.util.functional import cached_property -from disco.util.types import ListToDictType -from disco.types.base import BaseType from disco.types.user import User -from disco.types.permissions import PermissionType, Permissions, Permissible +from disco.types.permissions import Permissions, Permissible from disco.voice.client import VoiceClient @@ -23,7 +22,7 @@ PermissionOverwriteType = Enum( ) -class PermissionOverwrite(BaseType): +class PermissionOverwrite(Model): """ A PermissionOverwrite for a :class:`Channel` @@ -39,14 +38,14 @@ class PermissionOverwrite(BaseType): denied : :class:`PermissionValue` All denied permissions """ - id = skema.SnowflakeType() - type = skema.StringType(choices=PermissionOverwriteType.ALL_VALUES) - allow = PermissionType() - deny = PermissionType() + id = snowflake + type = enum(PermissionOverwriteType) + allow = PermissionValue + deny = PermissionValue -class Channel(BaseType, Permissible): +class Channel(Model, Permissible): """ Represents a Discord Channel @@ -71,19 +70,16 @@ class Channel(BaseType, Permissible): overwrites : dict(snowflake, :class:`disco.types.channel.PermissionOverwrite`) Channel permissions overwrites. """ - id = skema.SnowflakeType() - guild_id = skema.SnowflakeType(required=False) - - name = skema.StringType() - topic = skema.StringType() - _last_message_id = skema.SnowflakeType(stored_name='last_message_id') - position = skema.IntType() - bitrate = skema.IntType(required=False) - - recipients = skema.ListType(skema.ModelType(User)) - type = skema.IntType(choices=ChannelType.ALL_VALUES) - - overwrites = ListToDictType('id', skema.ModelType(PermissionOverwrite), stored_name='permission_overwrites') + id = snowflake + guild_id = snowflake + name = text + topic = text + _last_message_id = alias(snowflake, 'last_message_id') + position = int + bitrate = int + recipients = listof(User) + type = enum(ChannelType) + overwrites = alias(dictof(PermissionOverwrite, key='id'), 'permission_overwrites') def get_permissions(self, user): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index dd437fe..9cc6943 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -1,17 +1,15 @@ -import skema -import copy + +from disco.types.base import Model, snowflake, listof, dictof, datetime, text, binary from disco.api.http import APIException from disco.util import to_snowflake -from disco.util.types import PreHookType, ListToDictType -from disco.types.base import BaseType from disco.types.user import User from disco.types.voice import VoiceState -from disco.types.permissions import PermissionType, PermissionValue, Permissions, Permissible +from disco.types.permissions import PermissionValue, Permissions, Permissible from disco.types.channel import Channel -class Emoji(BaseType): +class Emoji(Model): """ An emoji object @@ -28,14 +26,14 @@ class Emoji(BaseType): roles : list(snowflake) Roles this emoji is attached to. """ - id = skema.SnowflakeType() - name = skema.StringType() - require_colons = skema.BooleanType() - managed = skema.BooleanType() - roles = skema.ListType(skema.SnowflakeType()) + id = snowflake + name = text + require_colons = bool + managed = bool + roles = listof(snowflake) -class Role(BaseType): +class Role(Model): """ A role object @@ -56,16 +54,16 @@ class Role(BaseType): position : int The position of this role in the hierarchy. """ - id = skema.SnowflakeType() - name = skema.StringType() - hoist = skema.BooleanType() - managed = skema.BooleanType() - color = skema.IntType() - permissions = PermissionType() - position = skema.IntType() + id = snowflake + name = text + hoist = bool + managed = bool + color = int + permissions = PermissionValue + position = int -class GuildMember(BaseType): +class GuildMember(Model): """ A GuildMember object @@ -84,12 +82,12 @@ class GuildMember(BaseType): roles : list(snowflake) Roles this member is part of. """ - user = skema.ModelType(User) - guild_id = skema.SnowflakeType(required=False) - mute = skema.BooleanType() - deaf = skema.BooleanType() - joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) - roles = skema.ListType(skema.SnowflakeType()) + user = User + guild_id = snowflake + mute = bool + deaf = bool + joined_at = datetime + roles = listof(snowflake) def get_voice_state(self): """ @@ -126,7 +124,7 @@ class GuildMember(BaseType): return self.user.id -class Guild(BaseType, Permissible): +class Guild(Model, Permissible): """ A guild object @@ -170,29 +168,24 @@ class Guild(BaseType, Permissible): All of the guilds voice states. """ - id = skema.SnowflakeType() - - owner_id = skema.SnowflakeType() - afk_channel_id = skema.SnowflakeType() - embed_channel_id = skema.SnowflakeType() - - name = skema.StringType() - icon = skema.BinaryType(None) - splash = skema.BinaryType(None) - region = skema.StringType() - - afk_timeout = skema.IntType() - embed_enabled = skema.BooleanType() - verification_level = skema.IntType() - mfa_level = skema.IntType() - - features = skema.ListType(skema.StringType()) - - members = ListToDictType('id', skema.ModelType(copy.deepcopy(GuildMember))) - channels = ListToDictType('id', skema.ModelType(Channel)) - roles = ListToDictType('id', skema.ModelType(Role)) - emojis = ListToDictType('id', skema.ModelType(Emoji)) - voice_states = ListToDictType('session_id', skema.ModelType(VoiceState)) + id = snowflake + owner_id = snowflake + afk_channel_id = snowflake + embed_channel_id = snowflake + name = text + icon = binary + splash = binary + region = str + afk_timeout = int + embed_enabled = bool + verification_level = int + mfa_level = int + features = listof(str) + members = dictof(GuildMember, key='id') + channels = dictof(Channel, key='id') + roles = dictof(Role, key='id') + emojis = dictof(Emoji, key='id') + voice_states = dictof(VoiceState, key='session_id') def get_permissions(self, user): """ diff --git a/disco/types/invite.py b/disco/types/invite.py index cb7a737..f8fbc43 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,13 +1,10 @@ -import skema - -from disco.util.types import PreHookType -from disco.types.base import BaseType +from disco.types.base import Model, datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel -class Invite(BaseType): +class Invite(Model): """ An invite object @@ -32,15 +29,12 @@ class Invite(BaseType): created_at : datetime When this invite was created. """ - code = skema.StringType() - - inviter = skema.ModelType(User) - guild = skema.ModelType(Guild) - channel = skema.ModelType(Channel) - - max_age = skema.IntType() - max_uses = skema.IntType() - uses = skema.IntType() - temporary = skema.BooleanType() - - created_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) + code = str + inviter = User + guild = Guild + channel = Channel + max_age = int + max_uses = int + uses = int + temporary = bool + created_at = datetime diff --git a/disco/types/message.py b/disco/types/message.py index 4f171be..87d62c3 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,14 +1,13 @@ import re -import skema + +from disco.types.base import Model, snowflake, text, datetime, dictof, listof from disco.util import to_snowflake from disco.util.functional import cached_property -from disco.util.types import PreHookType, ListToDictType -from disco.types.base import BaseType from disco.types.user import User -class MessageEmbed(BaseType): +class MessageEmbed(Model): """ Message embed object @@ -23,13 +22,13 @@ class MessageEmbed(BaseType): url : str URL of the embed. """ - title = skema.StringType() - type = skema.StringType() - description = skema.StringType() - url = skema.StringType() + title = text + type = str + description = text + url = str -class MessageAttachment(BaseType): +class MessageAttachment(Model): """ Message attachment object @@ -50,16 +49,16 @@ class MessageAttachment(BaseType): width : int Width of the attachment. """ - id = skema.SnowflakeType() - filename = skema.StringType() - url = skema.StringType() - proxy_url = skema.StringType() - size = skema.IntType() - height = skema.IntType() - width = skema.IntType() + id = str + filename = text + url = str + proxy_url = str + size = int + height = int + width = int -class Message(BaseType): +class Message(Model): """ Represents a Message created within a Channel on Discord. @@ -94,26 +93,20 @@ class Message(BaseType): attachments : list(:class:`MessageAttachment`) All attachments for this message. """ - id = skema.SnowflakeType() - channel_id = skema.SnowflakeType() - - author = skema.ModelType(User) - content = skema.StringType() - nonce = skema.StringType() - - timestamp = PreHookType(lambda k: k[:-6], skema.DateTimeType()) - edited_timestamp = PreHookType(lambda k: k[:-6], skema.DateTimeType()) - - tts = skema.BooleanType() - mention_everyone = skema.BooleanType() - - pinned = skema.BooleanType(required=False) - - mentions = ListToDictType('id', skema.ModelType(User)) - mention_roles = skema.ListType(skema.SnowflakeType()) - - embeds = skema.ListType(skema.ModelType(MessageEmbed)) - attachments = ListToDictType('id', skema.ModelType(MessageAttachment)) + id = snowflake + channel_id = snowflake + author = User + content = text + nonce = snowflake + timestamp = datetime + edited_timestamp = datetime + tts = bool + mention_everyone = bool + pinned = bool + mentions = dictof(User, key='id') + mention_roles = listof(snowflake) + embeds = listof(MessageEmbed) + attachments = dictof(MessageAttachment, key='id') def __str__(self): return ''.format(self.id, self.channel_id) diff --git a/disco/types/permissions.py b/disco/types/permissions.py index 76c0ce9..ec1c39e 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -1,5 +1,3 @@ -from skema import NumberType - from holster.enum import Enum, EnumAttr Permissions = Enum( @@ -104,11 +102,6 @@ class PermissionValue(object): return cls(66060288) -class PermissionType(NumberType): - def __init__(self, *args, **kwargs): - super(PermissionType, self).__init__(number_class=PermissionValue, number_type='PermissionValue', *args, **kwargs) - - class Permissible(object): def can(self, user, *args): perms = self.get_permissions(user) diff --git a/disco/types/user.py b/disco/types/user.py index 81d0800..73f6ff7 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -1,17 +1,13 @@ -import skema +from disco.types.base import Model, snowflake, text, binary -from disco.types.base import BaseType - -class User(BaseType): - id = skema.SnowflakeType() - - username = skema.StringType() - discriminator = skema.StringType() - avatar = skema.BinaryType(None) - - verified = skema.BooleanType(required=False) - email = skema.EmailType(required=False) +class User(Model): + id = snowflake + username = text + discriminator = str + avatar = binary + verified = bool + email = str def to_string(self): return '{}#{}'.format(self.username, self.discriminator) diff --git a/disco/types/voice.py b/disco/types/voice.py index 843f04b..a73e180 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -1,20 +1,16 @@ -import skema - -from disco.types.base import BaseType - - -class VoiceState(BaseType): - session_id = skema.StringType() - - guild_id = skema.SnowflakeType() - channel_id = skema.SnowflakeType() - user_id = skema.SnowflakeType() - - deaf = skema.BooleanType() - mute = skema.BooleanType() - self_deaf = skema.BooleanType() - self_mute = skema.BooleanType() - suppress = skema.BooleanType() +from disco.types.base import Model, snowflake + + +class VoiceState(Model): + session_id = str + guild_id = snowflake + channel_id = snowflake + user_id = snowflake + deaf = bool + mute = bool + self_deaf = bool + self_mute = bool + suppress = bool @property def guild(self): diff --git a/disco/util/__init__.py b/disco/util/__init__.py index 78a14b8..2a0f667 100644 --- a/disco/util/__init__.py +++ b/disco/util/__init__.py @@ -1,5 +1,4 @@ import six -import skema def to_snowflake(i): @@ -11,38 +10,3 @@ def to_snowflake(i): return i.id raise Exception('{} ({}) is not convertable to a snowflake'.format(type(i), i)) - - -def _recurse(typ, field, value): - result = [] - - if isinstance(field, skema.ModelType): - result += skema_find_recursive_by_type(value, typ) - - if isinstance(field, (skema.ListType, skema.SetType, skema.DictType)): - if isinstance(field, skema.DictType): - value = value.values() - - for item in value: - if isinstance(field.field, typ): - result.append((field.field, item)) - result += _recurse(typ, field.field, item) - - return result - - -def skema_find_recursive_by_type(base, typ): - result = [] - - for name, field in base._fields_by_stored_name.items(): - v = getattr(base, name, None) - - if not v: - continue - - if isinstance(field, typ): - result.append((field, v)) - - result += _recurse(typ, field, v) - - return result diff --git a/disco/util/types.py b/disco/util/types.py deleted file mode 100644 index 46e4d39..0000000 --- a/disco/util/types.py +++ /dev/null @@ -1,36 +0,0 @@ -from skema import BaseType, DictType - - -class PreHookType(BaseType): - _hashable = False - - def __init__(self, func, field, **kwargs): - self.func = func - self.field = field - - super(PreHookType, self).__init__(**kwargs) - - def to_python(self, value): - value = self.func(value) - return self.field.to_python(value) - - def to_storage(self, *args, **kwargs): - return self.field.to_storage(*args, **kwargs) - - -class ListToDictType(DictType): - def __init__(self, key, *args, **kwargs): - super(ListToDictType, self).__init__(*args, **kwargs) - self.key = key - - def to_python(self, value): - if not value: - return {} - - to_python = self.field.to_python - - obj = {} - for item in value: - item = to_python(item) - obj[getattr(item, self.key)] = item - return obj diff --git a/requirements.txt b/requirements.txt index 6963f84..d0a50b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,4 @@ holster==1.0.1 inflection==0.3.1 requests==2.11.1 six==1.10.0 -# skema==0.0.1 websocket-client==0.37.0 From a3f360535a206f32bec3ddbe320034acc3ed90ab Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 7 Oct 2016 03:36:25 -0500 Subject: [PATCH 2/3] Various fixes --- disco/gateway/events.py | 6 ++-- disco/state.py | 4 +++ disco/types/base.py | 65 ++++++++++++++++++++++++++++++++--------- 3 files changed, 58 insertions(+), 17 deletions(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index c882490..d8451ef 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -13,10 +13,10 @@ class GatewayEvent(Model): if not cls: raise Exception('Could not find cls for {}'.format(data['t'])) - return cls.create(data['d']) + return cls.create(data['d'], client) @classmethod - def create(cls, obj): + def create(cls, obj, client): # If this event is wrapping a model, pull its fields if hasattr(cls, '_wraps_model'): alias, model = cls._wraps_model @@ -27,7 +27,7 @@ class GatewayEvent(Model): obj[alias] = data - return cls(obj) + return cls(obj, client) def wraps_model(model, alias=None): diff --git a/disco/state.py b/disco/state.py index f32b4ba..90d2b22 100644 --- a/disco/state.py +++ b/disco/state.py @@ -154,6 +154,10 @@ class State(object): self.guilds[event.guild.id] = event.guild self.channels.update(event.guild.channels) + for channel in event.guild.channels.values(): + channel.guild_id = event.guild.id + channel.guild = event.guild + for member in event.guild.members.values(): self.users[member.user.id] = member.user diff --git a/disco/types/base.py b/disco/types/base.py index 39b2f28..0279beb 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -5,28 +5,43 @@ import functools from datetime import datetime as real_datetime +def _make(typ, data, client): + args, _, _, _ = inspect.getargspec(typ) + if 'client' in args: + return typ(data, client) + return typ(data) + + def snowflake(data): - return int(data) + return int(data) if data else None def enum(typ): def _f(data): - return typ.get(data) + return typ.get(data) if data else None return _f def listof(typ): - def _f(data): - return list(map(typ, data)) + def _f(data, client=None): + if not data: + return [] + return [_make(typ, obj, client) for obj in data] return _f def dictof(typ, key=None): - def _f(data): + def _f(data, client=None): + if not data: + return {} + if key: - return {getattr(v, key): v for v in map(typ, data)} + return { + getattr(v, key): v for v in ( + _make(typ, i, client) for i in data + )} else: - return {k: typ(v) for k, v in six.iteritems(data)} + return {k: _make(typ, v, client) for k, v in six.iteritems(data)} return _f @@ -34,16 +49,16 @@ def alias(typ, name): return ('alias', name, typ) -def datetime(typ): - return real_datetime.strptime(typ.rsplit('+', 1)[0], '%Y-%m-%dT%H:%M:%S.%f') +def datetime(data): + return real_datetime.strptime(data.rsplit('+', 1)[0], '%Y-%m-%dT%H:%M:%S.%f') if data else None def text(obj): - return six.text_type(obj) + return six.text_type(obj) if obj else six.text_type() def binary(obj): - return six.text_type(obj) + return six.text_type(obj) if obj else six.text_type() class ModelMeta(type): @@ -55,7 +70,13 @@ class ModelMeta(type): fields[v[1]] = (k, v[2]) continue - if callable(v) or inspect.isclass(v): + if inspect.isclass(v): + fields[k] = v + elif callable(v): + args, _, _, _ = inspect.getargspec(v) + if 'self' in args: + continue + fields[k] = v dct['_fields'] = fields @@ -64,6 +85,8 @@ class ModelMeta(type): class Model(six.with_metaclass(ModelMeta)): def __init__(self, obj, client=None): + self.client = client + for name, typ in self.__class__._fields.items(): dest_name = name @@ -71,10 +94,24 @@ class Model(six.with_metaclass(ModelMeta)): dest_name, typ = typ if name not in obj or not obj[name]: + if inspect.isclass(typ) and issubclass(typ, Model): + res = None + elif isinstance(typ, type): + res = typ() + else: + res = typ(None) + setattr(self, dest_name, res) continue try: - v = typ(obj[name]) + if client: + args, _, _, _ = inspect.getargspec(typ) + if 'client' in args: + v = typ(obj[name], client) + else: + v = typ(obj[name]) + else: + v = typ(obj[name]) except TypeError as e: print('Failed during parsing of field {} => {} (`{}`)'.format(name, typ, obj[name])) raise e @@ -92,7 +129,7 @@ class Model(six.with_metaclass(ModelMeta)): @classmethod def create(cls, client, data): - return cls(data) + return cls(data, client) @classmethod def create_map(cls, client, data): From 4036f21a8937a110b16e5633bb0e8adddcc2f3ac Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 7 Oct 2016 04:04:38 -0500 Subject: [PATCH 3/3] etc fixes --- disco/gateway/events.py | 4 ++++ disco/types/base.py | 30 +++++++++++++++++++++++++----- disco/types/channel.py | 7 +++++-- disco/types/message.py | 17 ++++++++++++++++- 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index d8451ef..367c9af 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -73,10 +73,12 @@ class ChannelCreate(GatewayEvent): return self.channel.guild +@wraps_model(Channel) class ChannelUpdate(ChannelCreate): pass +@wraps_model(Channel) class ChannelDelete(ChannelCreate): pass @@ -91,6 +93,7 @@ class GuildBanAdd(GatewayEvent): pass +@wraps_model(User) class GuildBanRemove(GuildBanAdd): pass @@ -144,6 +147,7 @@ class MessageCreate(GatewayEvent): return self.message.channel +@wraps_model(Message) class MessageUpdate(MessageCreate): pass diff --git a/disco/types/base.py b/disco/types/base.py index 0279beb..6ffab85 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -4,6 +4,11 @@ import functools from datetime import datetime as real_datetime +DATETIME_FORMATS = [ + '%Y-%m-%dT%H:%M:%S.%f', + '%Y-%m-%dT%H:%M:%S' +] + def _make(typ, data, client): args, _, _, _ = inspect.getargspec(typ) @@ -50,7 +55,16 @@ def alias(typ, name): def datetime(data): - return real_datetime.strptime(data.rsplit('+', 1)[0], '%Y-%m-%dT%H:%M:%S.%f') if data else None + if not data: + return None + + for fmt in DATETIME_FORMATS: + try: + return real_datetime.strptime(data.rsplit('+', 1)[0], fmt) + except (ValueError, TypeError): + continue + + raise ValueError('Failed to conver `{}` to datetime'.format(data)) def text(obj): @@ -112,9 +126,9 @@ class Model(six.with_metaclass(ModelMeta)): v = typ(obj[name]) else: v = typ(obj[name]) - except TypeError as e: + except Exception: print('Failed during parsing of field {} => {} (`{}`)'.format(name, typ, obj[name])) - raise e + raise if client and isinstance(v, Model): v.client = client @@ -122,15 +136,21 @@ class Model(six.with_metaclass(ModelMeta)): setattr(self, dest_name, v) def update(self, other): - for name in six.iterkeys(self.__class__.fields): + for name in six.iterkeys(self.__class__._fields): value = getattr(other, name) if value: setattr(self, name, value) + # Clear cached properties + for name in dir(type(self)): + if isinstance(getattr(type(self), name), property): + delattr(self, name) + + @classmethod def create(cls, client, data): return cls(data, client) @classmethod def create_map(cls, client, data): - return map(functools.partial(cls.create, client), data) + return list(map(functools.partial(cls.create, client), data)) diff --git a/disco/types/channel.py b/disco/types/channel.py index 030dbf9..e1beea4 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -240,7 +240,7 @@ class MessageIterator(object): self.last = None self._buffer = [] - if len(filter(bool, (before, after))) > 1: + if not before and not after: raise Exception('Must specify at most one of before or after') if not any((before, after)) and self.direction == self.Direction.DOWN: @@ -268,10 +268,13 @@ class MessageIterator(object): self._buffer.reverse() self.after == self._buffer[-1].id + def next(self): + return self.__next__() + def __iter__(self): return self - def next(self): + def __next__(self): if not len(self._buffer): self.fill() diff --git a/disco/types/message.py b/disco/types/message.py index 87d62c3..d3aa2e6 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,12 +1,24 @@ import re -from disco.types.base import Model, snowflake, text, datetime, dictof, listof +from holster.enum import Enum +from disco.types.base import Model, snowflake, text, datetime, dictof, listof, enum from disco.util import to_snowflake from disco.util.functional import cached_property from disco.types.user import User +MessageType = Enum( + DEFAULT=0, + RECIPIENT_ADD=1, + RECIPIENT_REMOVE=2, + CALL=3, + CHANNEL_NAME_CHANGE=4, + CHANNEL_ICON_CHANGE=5, + PINS_ADD=6 +) + + class MessageEmbed(Model): """ Message embed object @@ -68,6 +80,8 @@ class Message(Model): The ID of this message. channel_id : snowflake The channel ID this message was sent in. + type : ``MessageType`` + Type of the message. author : :class:`disco.types.user.User` The author of this message. content : str @@ -95,6 +109,7 @@ class Message(Model): """ id = snowflake channel_id = snowflake + type = enum(MessageType) author = User content = text nonce = snowflake