diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 5d1314d..c882490 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -1,25 +1,19 @@ import inflection -import skema import six -from disco.util import skema_find_recursive_by_type from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState +from disco.types.base import Model, snowflake, alias, listof # TODO: clean this... use BaseType, etc -class GatewayEvent(skema.Model): +class GatewayEvent(Model): @staticmethod def from_dispatch(client, data): cls = globals().get(inflection.camelize(data['t'].lower())) if not cls: raise Exception('Could not find cls for {}'.format(data['t'])) - obj = cls.create(data['d']) - - for field, value in skema_find_recursive_by_type(obj, skema.ModelType): - value.client = client - - return obj + return cls.create(data['d']) @classmethod def create(cls, obj): @@ -28,31 +22,29 @@ class GatewayEvent(skema.Model): alias, model = cls._wraps_model data = { - k: obj.pop(k) for k in six.iterkeys(model._fields_by_stored_name) if k in obj + k: obj.pop(k) for k in six.iterkeys(model._fields) if k in obj } obj[alias] = data - self = cls(obj) - self.validate() - return self + return cls(obj) def wraps_model(model, alias=None): alias = alias or model.__name__.lower() def deco(cls): - cls.add_field(alias, skema.ModelType(model)) + cls._fields[alias] = model cls._wraps_model = (alias, model) return cls return deco class Ready(GatewayEvent): - version = skema.IntType(stored_name='v') - session_id = skema.StringType() - user = skema.ModelType(User) - guilds = skema.ListType(skema.ModelType(Guild)) + version = alias(int, 'v') + session_id = str + user = User + guilds = listof(Guild) class Resumed(GatewayEvent): @@ -61,17 +53,17 @@ class Resumed(GatewayEvent): @wraps_model(Guild) class GuildCreate(GatewayEvent): - unavailable = skema.BooleanType(default=None) + unavailable = bool @wraps_model(Guild) class GuildUpdate(GatewayEvent): - guild = skema.ModelType(Guild) + pass class GuildDelete(GatewayEvent): - id = skema.SnowflakeType() - unavailable = skema.BooleanType(default=None) + id = snowflake + unavailable = bool @wraps_model(Channel) @@ -90,8 +82,8 @@ class ChannelDelete(ChannelCreate): class ChannelPinsUpdate(GatewayEvent): - channel_id = skema.SnowflakeType() - last_pin_timestamp = skema.IntType() + channel_id = snowflake + last_pin_timestamp = int @wraps_model(User) @@ -112,8 +104,8 @@ class GuildIntegrationsUpdate(GatewayEvent): class GuildMembersChunk(GatewayEvent): - guild_id = skema.SnowflakeType() - members = skema.ListType(skema.ModelType(GuildMember)) + guild_id = snowflake + members = listof(GuildMember) @wraps_model(GuildMember, alias='member') @@ -122,29 +114,27 @@ class GuildMemberAdd(GatewayEvent): class GuildMemberRemove(GatewayEvent): - guild_id = skema.SnowflakeType() - user = skema.ModelType(User) + guild_id = snowflake + user = User class GuildMemberUpdate(GatewayEvent): - guild_id = skema.SnowflakeType() - user = skema.ModelType(User) - roles = skema.ListType(skema.SnowflakeType()) + guild_id = snowflake + user = User + roles = listof(snowflake) class GuildRoleCreate(GatewayEvent): - guild_id = skema.SnowflakeType() - role = skema.ModelType(Role) + guild_id = snowflake + role = Role -class GuildRoleUpdate(GatewayEvent): - guild_id = skema.SnowflakeType() - role = skema.ModelType(Role) +class GuildRoleUpdate(GuildRoleCreate): + pass -class GuildRoleDelete(GatewayEvent): - guild_id = skema.SnowflakeType() - role = skema.ModelType(Role) +class GuildRoleDelete(GuildRoleCreate): + pass @wraps_model(Message) @@ -159,32 +149,33 @@ class MessageUpdate(MessageCreate): class MessageDelete(GatewayEvent): - id = skema.SnowflakeType() - channel_id = skema.SnowflakeType() + id = snowflake + channel_id = snowflake class MessageDeleteBulk(GatewayEvent): - channel_id = skema.SnowflakeType() - ids = skema.ListType(skema.SnowflakeType()) + channel_id = snowflake + ids = listof(snowflake) class PresenceUpdate(GatewayEvent): - class Game(skema.Model): - type = skema.IntType() - name = skema.StringType() - url = skema.StringType(required=False) + class Game(Model): + # TODO enum + type = int + name = str + url = str - user = skema.ModelType(User) - guild_id = skema.SnowflakeType() - roles = skema.ListType(skema.SnowflakeType()) - game = skema.ModelType(Game) - status = skema.StringType() + user = User + guild_id = snowflake + roles = listof(snowflake) + game = Game + status = str class TypingStart(GatewayEvent): - channel_id = skema.SnowflakeType() - user_id = skema.SnowflakeType() - timestamp = skema.IntType() + channel_id = snowflake + user_id = snowflake + timestamp = snowflake @wraps_model(VoiceState, alias='state') @@ -193,6 +184,6 @@ class VoiceStateUpdate(GatewayEvent): class VoiceServerUpdate(GatewayEvent): - token = skema.StringType() - endpoint = skema.StringType() - guild_id = skema.SnowflakeType() + token = str + endpoint = str + guild_id = snowflake diff --git a/disco/types/base.py b/disco/types/base.py index 2f9c41c..39b2f28 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -1,29 +1,98 @@ -import skema +import six +import inspect import functools -from disco.util import skema_find_recursive_by_type -# from disco.util.types import DeferredModel +from datetime import datetime as real_datetime -class BaseType(skema.Model): +def snowflake(data): + return int(data) + + +def enum(typ): + def _f(data): + return typ.get(data) + return _f + + +def listof(typ): + def _f(data): + return list(map(typ, data)) + return _f + + +def dictof(typ, key=None): + def _f(data): + if key: + return {getattr(v, key): v for v in map(typ, data)} + else: + return {k: typ(v) for k, v in six.iteritems(data)} + return _f + + +def alias(typ, name): + return ('alias', name, typ) + + +def datetime(typ): + return real_datetime.strptime(typ.rsplit('+', 1)[0], '%Y-%m-%dT%H:%M:%S.%f') + + +def text(obj): + return six.text_type(obj) + + +def binary(obj): + return six.text_type(obj) + + +class ModelMeta(type): + def __new__(cls, name, parents, dct): + fields = {} + for k, v in six.iteritems(dct): + if isinstance(v, tuple): + if v[0] == 'alias': + fields[v[1]] = (k, v[2]) + continue + + if callable(v) or inspect.isclass(v): + fields[k] = v + + dct['_fields'] = fields + return super(ModelMeta, cls).__new__(cls, name, parents, dct) + + +class Model(six.with_metaclass(ModelMeta)): + def __init__(self, obj, client=None): + for name, typ in self.__class__._fields.items(): + dest_name = name + + if isinstance(typ, tuple): + dest_name, typ = typ + + if name not in obj or not obj[name]: + continue + + try: + v = typ(obj[name]) + except TypeError as e: + print('Failed during parsing of field {} => {} (`{}`)'.format(name, typ, obj[name])) + raise e + + if client and isinstance(v, Model): + v.client = client + + setattr(self, dest_name, v) + def update(self, other): - for name, field in other.__class__._fields.items(): + for name in six.iterkeys(self.__class__.fields): value = getattr(other, name) if value: setattr(self, name, value) @classmethod def create(cls, client, data): - obj = cls(data) - - # Valdiate - obj.validate() - - for field, value in skema_find_recursive_by_type(obj, skema.ModelType): - value.client = client - - obj.client = client - return obj + return cls(data) @classmethod def create_map(cls, client, data): diff --git a/disco/types/channel.py b/disco/types/channel.py index 6ef1eb1..030dbf9 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,12 +1,11 @@ -import skema - from holster.enum import Enum +from disco.types.base import Model, snowflake, enum, listof, dictof, alias, text +from disco.types.permissions import PermissionValue + from disco.util.functional import cached_property -from disco.util.types import ListToDictType -from disco.types.base import BaseType from disco.types.user import User -from disco.types.permissions import PermissionType, Permissions, Permissible +from disco.types.permissions import Permissions, Permissible from disco.voice.client import VoiceClient @@ -23,7 +22,7 @@ PermissionOverwriteType = Enum( ) -class PermissionOverwrite(BaseType): +class PermissionOverwrite(Model): """ A PermissionOverwrite for a :class:`Channel` @@ -39,14 +38,14 @@ class PermissionOverwrite(BaseType): denied : :class:`PermissionValue` All denied permissions """ - id = skema.SnowflakeType() - type = skema.StringType(choices=PermissionOverwriteType.ALL_VALUES) - allow = PermissionType() - deny = PermissionType() + id = snowflake + type = enum(PermissionOverwriteType) + allow = PermissionValue + deny = PermissionValue -class Channel(BaseType, Permissible): +class Channel(Model, Permissible): """ Represents a Discord Channel @@ -71,19 +70,16 @@ class Channel(BaseType, Permissible): overwrites : dict(snowflake, :class:`disco.types.channel.PermissionOverwrite`) Channel permissions overwrites. """ - id = skema.SnowflakeType() - guild_id = skema.SnowflakeType(required=False) - - name = skema.StringType() - topic = skema.StringType() - _last_message_id = skema.SnowflakeType(stored_name='last_message_id') - position = skema.IntType() - bitrate = skema.IntType(required=False) - - recipients = skema.ListType(skema.ModelType(User)) - type = skema.IntType(choices=ChannelType.ALL_VALUES) - - overwrites = ListToDictType('id', skema.ModelType(PermissionOverwrite), stored_name='permission_overwrites') + id = snowflake + guild_id = snowflake + name = text + topic = text + _last_message_id = alias(snowflake, 'last_message_id') + position = int + bitrate = int + recipients = listof(User) + type = enum(ChannelType) + overwrites = alias(dictof(PermissionOverwrite, key='id'), 'permission_overwrites') def get_permissions(self, user): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index dd437fe..9cc6943 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -1,17 +1,15 @@ -import skema -import copy + +from disco.types.base import Model, snowflake, listof, dictof, datetime, text, binary from disco.api.http import APIException from disco.util import to_snowflake -from disco.util.types import PreHookType, ListToDictType -from disco.types.base import BaseType from disco.types.user import User from disco.types.voice import VoiceState -from disco.types.permissions import PermissionType, PermissionValue, Permissions, Permissible +from disco.types.permissions import PermissionValue, Permissions, Permissible from disco.types.channel import Channel -class Emoji(BaseType): +class Emoji(Model): """ An emoji object @@ -28,14 +26,14 @@ class Emoji(BaseType): roles : list(snowflake) Roles this emoji is attached to. """ - id = skema.SnowflakeType() - name = skema.StringType() - require_colons = skema.BooleanType() - managed = skema.BooleanType() - roles = skema.ListType(skema.SnowflakeType()) + id = snowflake + name = text + require_colons = bool + managed = bool + roles = listof(snowflake) -class Role(BaseType): +class Role(Model): """ A role object @@ -56,16 +54,16 @@ class Role(BaseType): position : int The position of this role in the hierarchy. """ - id = skema.SnowflakeType() - name = skema.StringType() - hoist = skema.BooleanType() - managed = skema.BooleanType() - color = skema.IntType() - permissions = PermissionType() - position = skema.IntType() + id = snowflake + name = text + hoist = bool + managed = bool + color = int + permissions = PermissionValue + position = int -class GuildMember(BaseType): +class GuildMember(Model): """ A GuildMember object @@ -84,12 +82,12 @@ class GuildMember(BaseType): roles : list(snowflake) Roles this member is part of. """ - user = skema.ModelType(User) - guild_id = skema.SnowflakeType(required=False) - mute = skema.BooleanType() - deaf = skema.BooleanType() - joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) - roles = skema.ListType(skema.SnowflakeType()) + user = User + guild_id = snowflake + mute = bool + deaf = bool + joined_at = datetime + roles = listof(snowflake) def get_voice_state(self): """ @@ -126,7 +124,7 @@ class GuildMember(BaseType): return self.user.id -class Guild(BaseType, Permissible): +class Guild(Model, Permissible): """ A guild object @@ -170,29 +168,24 @@ class Guild(BaseType, Permissible): All of the guilds voice states. """ - id = skema.SnowflakeType() - - owner_id = skema.SnowflakeType() - afk_channel_id = skema.SnowflakeType() - embed_channel_id = skema.SnowflakeType() - - name = skema.StringType() - icon = skema.BinaryType(None) - splash = skema.BinaryType(None) - region = skema.StringType() - - afk_timeout = skema.IntType() - embed_enabled = skema.BooleanType() - verification_level = skema.IntType() - mfa_level = skema.IntType() - - features = skema.ListType(skema.StringType()) - - members = ListToDictType('id', skema.ModelType(copy.deepcopy(GuildMember))) - channels = ListToDictType('id', skema.ModelType(Channel)) - roles = ListToDictType('id', skema.ModelType(Role)) - emojis = ListToDictType('id', skema.ModelType(Emoji)) - voice_states = ListToDictType('session_id', skema.ModelType(VoiceState)) + id = snowflake + owner_id = snowflake + afk_channel_id = snowflake + embed_channel_id = snowflake + name = text + icon = binary + splash = binary + region = str + afk_timeout = int + embed_enabled = bool + verification_level = int + mfa_level = int + features = listof(str) + members = dictof(GuildMember, key='id') + channels = dictof(Channel, key='id') + roles = dictof(Role, key='id') + emojis = dictof(Emoji, key='id') + voice_states = dictof(VoiceState, key='session_id') def get_permissions(self, user): """ diff --git a/disco/types/invite.py b/disco/types/invite.py index cb7a737..f8fbc43 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,13 +1,10 @@ -import skema - -from disco.util.types import PreHookType -from disco.types.base import BaseType +from disco.types.base import Model, datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel -class Invite(BaseType): +class Invite(Model): """ An invite object @@ -32,15 +29,12 @@ class Invite(BaseType): created_at : datetime When this invite was created. """ - code = skema.StringType() - - inviter = skema.ModelType(User) - guild = skema.ModelType(Guild) - channel = skema.ModelType(Channel) - - max_age = skema.IntType() - max_uses = skema.IntType() - uses = skema.IntType() - temporary = skema.BooleanType() - - created_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) + code = str + inviter = User + guild = Guild + channel = Channel + max_age = int + max_uses = int + uses = int + temporary = bool + created_at = datetime diff --git a/disco/types/message.py b/disco/types/message.py index 4f171be..87d62c3 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -1,14 +1,13 @@ import re -import skema + +from disco.types.base import Model, snowflake, text, datetime, dictof, listof from disco.util import to_snowflake from disco.util.functional import cached_property -from disco.util.types import PreHookType, ListToDictType -from disco.types.base import BaseType from disco.types.user import User -class MessageEmbed(BaseType): +class MessageEmbed(Model): """ Message embed object @@ -23,13 +22,13 @@ class MessageEmbed(BaseType): url : str URL of the embed. """ - title = skema.StringType() - type = skema.StringType() - description = skema.StringType() - url = skema.StringType() + title = text + type = str + description = text + url = str -class MessageAttachment(BaseType): +class MessageAttachment(Model): """ Message attachment object @@ -50,16 +49,16 @@ class MessageAttachment(BaseType): width : int Width of the attachment. """ - id = skema.SnowflakeType() - filename = skema.StringType() - url = skema.StringType() - proxy_url = skema.StringType() - size = skema.IntType() - height = skema.IntType() - width = skema.IntType() + id = str + filename = text + url = str + proxy_url = str + size = int + height = int + width = int -class Message(BaseType): +class Message(Model): """ Represents a Message created within a Channel on Discord. @@ -94,26 +93,20 @@ class Message(BaseType): attachments : list(:class:`MessageAttachment`) All attachments for this message. """ - id = skema.SnowflakeType() - channel_id = skema.SnowflakeType() - - author = skema.ModelType(User) - content = skema.StringType() - nonce = skema.StringType() - - timestamp = PreHookType(lambda k: k[:-6], skema.DateTimeType()) - edited_timestamp = PreHookType(lambda k: k[:-6], skema.DateTimeType()) - - tts = skema.BooleanType() - mention_everyone = skema.BooleanType() - - pinned = skema.BooleanType(required=False) - - mentions = ListToDictType('id', skema.ModelType(User)) - mention_roles = skema.ListType(skema.SnowflakeType()) - - embeds = skema.ListType(skema.ModelType(MessageEmbed)) - attachments = ListToDictType('id', skema.ModelType(MessageAttachment)) + id = snowflake + channel_id = snowflake + author = User + content = text + nonce = snowflake + timestamp = datetime + edited_timestamp = datetime + tts = bool + mention_everyone = bool + pinned = bool + mentions = dictof(User, key='id') + mention_roles = listof(snowflake) + embeds = listof(MessageEmbed) + attachments = dictof(MessageAttachment, key='id') def __str__(self): return ''.format(self.id, self.channel_id) diff --git a/disco/types/permissions.py b/disco/types/permissions.py index 76c0ce9..ec1c39e 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -1,5 +1,3 @@ -from skema import NumberType - from holster.enum import Enum, EnumAttr Permissions = Enum( @@ -104,11 +102,6 @@ class PermissionValue(object): return cls(66060288) -class PermissionType(NumberType): - def __init__(self, *args, **kwargs): - super(PermissionType, self).__init__(number_class=PermissionValue, number_type='PermissionValue', *args, **kwargs) - - class Permissible(object): def can(self, user, *args): perms = self.get_permissions(user) diff --git a/disco/types/user.py b/disco/types/user.py index 81d0800..73f6ff7 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -1,17 +1,13 @@ -import skema +from disco.types.base import Model, snowflake, text, binary -from disco.types.base import BaseType - -class User(BaseType): - id = skema.SnowflakeType() - - username = skema.StringType() - discriminator = skema.StringType() - avatar = skema.BinaryType(None) - - verified = skema.BooleanType(required=False) - email = skema.EmailType(required=False) +class User(Model): + id = snowflake + username = text + discriminator = str + avatar = binary + verified = bool + email = str def to_string(self): return '{}#{}'.format(self.username, self.discriminator) diff --git a/disco/types/voice.py b/disco/types/voice.py index 843f04b..a73e180 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -1,20 +1,16 @@ -import skema - -from disco.types.base import BaseType - - -class VoiceState(BaseType): - session_id = skema.StringType() - - guild_id = skema.SnowflakeType() - channel_id = skema.SnowflakeType() - user_id = skema.SnowflakeType() - - deaf = skema.BooleanType() - mute = skema.BooleanType() - self_deaf = skema.BooleanType() - self_mute = skema.BooleanType() - suppress = skema.BooleanType() +from disco.types.base import Model, snowflake + + +class VoiceState(Model): + session_id = str + guild_id = snowflake + channel_id = snowflake + user_id = snowflake + deaf = bool + mute = bool + self_deaf = bool + self_mute = bool + suppress = bool @property def guild(self): diff --git a/disco/util/__init__.py b/disco/util/__init__.py index 78a14b8..2a0f667 100644 --- a/disco/util/__init__.py +++ b/disco/util/__init__.py @@ -1,5 +1,4 @@ import six -import skema def to_snowflake(i): @@ -11,38 +10,3 @@ def to_snowflake(i): return i.id raise Exception('{} ({}) is not convertable to a snowflake'.format(type(i), i)) - - -def _recurse(typ, field, value): - result = [] - - if isinstance(field, skema.ModelType): - result += skema_find_recursive_by_type(value, typ) - - if isinstance(field, (skema.ListType, skema.SetType, skema.DictType)): - if isinstance(field, skema.DictType): - value = value.values() - - for item in value: - if isinstance(field.field, typ): - result.append((field.field, item)) - result += _recurse(typ, field.field, item) - - return result - - -def skema_find_recursive_by_type(base, typ): - result = [] - - for name, field in base._fields_by_stored_name.items(): - v = getattr(base, name, None) - - if not v: - continue - - if isinstance(field, typ): - result.append((field, v)) - - result += _recurse(typ, field, v) - - return result diff --git a/disco/util/types.py b/disco/util/types.py deleted file mode 100644 index 46e4d39..0000000 --- a/disco/util/types.py +++ /dev/null @@ -1,36 +0,0 @@ -from skema import BaseType, DictType - - -class PreHookType(BaseType): - _hashable = False - - def __init__(self, func, field, **kwargs): - self.func = func - self.field = field - - super(PreHookType, self).__init__(**kwargs) - - def to_python(self, value): - value = self.func(value) - return self.field.to_python(value) - - def to_storage(self, *args, **kwargs): - return self.field.to_storage(*args, **kwargs) - - -class ListToDictType(DictType): - def __init__(self, key, *args, **kwargs): - super(ListToDictType, self).__init__(*args, **kwargs) - self.key = key - - def to_python(self, value): - if not value: - return {} - - to_python = self.field.to_python - - obj = {} - for item in value: - item = to_python(item) - obj[getattr(item, self.key)] = item - return obj diff --git a/requirements.txt b/requirements.txt index 6963f84..d0a50b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,4 @@ holster==1.0.1 inflection==0.3.1 requests==2.11.1 six==1.10.0 -# skema==0.0.1 websocket-client==0.37.0