From ffe5a6f6c8a23fc09f218b2bef54f61eab893335 Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 7 Oct 2016 12:41:58 -0500 Subject: [PATCH] Refactor modeling to avoid magic as much as possible My previous stab at implementing the simple-modeling-orm-thing-tm failed in the aspect that there was a lot of duplicated code doing runtime inspection of stuff. This was due mostly to having no extra place to store information on types, making it hard to introspect how the type expected to be built, whether it had a default, etc. This commit refactors the modeling code to actually have a Field type, which wraps some information up in a simple class and allows extremely easy conversion without having to do (more) expensive runtime inspection. This also gives us the benefits of a much more readable/cleaner code, expandable field options, and not having to fuck with sphinx to get docs working correctly (it was duping attributes because they where aliases...) --- disco/gateway/events.py | 169 +++++++++++++++++++++++++++++---------- disco/types/base.py | 151 ++++++++++++++++++---------------- disco/types/channel.py | 30 +++---- disco/types/guild.py | 76 +++++++++--------- disco/types/invite.py | 20 ++--- disco/types/message.py | 54 ++++++------- disco/types/user.py | 14 ++-- disco/types/voice.py | 20 ++--- docs/api.rst | 2 +- examples/basic_plugin.py | 6 +- 10 files changed, 322 insertions(+), 220 deletions(-) diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6d809de..fd250a1 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -2,7 +2,7 @@ import inflection import six from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState -from disco.types.base import Model, snowflake, alias, listof, text +from disco.types.base import Model, Field, snowflake, listof, text # TODO: clean this... use BaseType, etc @@ -29,164 +29,251 @@ class GatewayEvent(Model): return cls(obj, client) + def __getattr__(self, name): + if hasattr(self, '_wraps_model'): + modname, _ = self._wraps_model + if hasattr(self, modname) and hasattr(getattr(self, modname), name): + return getattr(getattr(self, modname), name) + return object.__getattr__(self, name) + def wraps_model(model, alias=None): alias = alias or model.__name__.lower() def deco(cls): - cls._fields[alias] = model + cls._fields[alias] = Field(model) + cls._fields[alias].set_name(alias) cls._wraps_model = (alias, model) return cls return deco class Ready(GatewayEvent): - version = alias(int, 'v') - session_id = str - user = User - guilds = listof(Guild) + """ + Sent after the initial gateway handshake is complete. Contains data required + for bootstrapping the clients states. + """ + version = Field(int, alias='v') + session_id = Field(str) + user = Field(User) + guilds = Field(listof(Guild)) class Resumed(GatewayEvent): + """ + Sent after a resume completes. + """ pass @wraps_model(Guild) class GuildCreate(GatewayEvent): - unavailable = bool + """ + Sent when a guild is created, or becomes available. + """ + unavailable = Field(bool) @wraps_model(Guild) class GuildUpdate(GatewayEvent): + """ + Sent when a guild is updated. + """ pass class GuildDelete(GatewayEvent): - id = snowflake - unavailable = bool + """ + Sent when a guild is deleted, or becomes unavailable. + """ + id = Field(snowflake) + unavailable = Field(bool) @wraps_model(Channel) class ChannelCreate(GatewayEvent): - @property - def guild(self): - return self.channel.guild + """ + Sent when a channel is created. + """ @wraps_model(Channel) class ChannelUpdate(ChannelCreate): + """ + Sent when a channel is updated. + """ pass @wraps_model(Channel) class ChannelDelete(ChannelCreate): + """ + Sent when a channel is deleted. + """ pass class ChannelPinsUpdate(GatewayEvent): - channel_id = snowflake - last_pin_timestamp = int + """ + Sent when a channels pins are updated. + """ + channel_id = Field(snowflake) + last_pin_timestamp = Field(int) @wraps_model(User) class GuildBanAdd(GatewayEvent): + """ + Sent when a user is banned from a guild. + """ pass @wraps_model(User) class GuildBanRemove(GuildBanAdd): + """ + Sent when a user is unbanned from a guild. + """ pass class GuildEmojisUpdate(GatewayEvent): + """ + Sent when a guilds emojis are updated. + """ pass class GuildIntegrationsUpdate(GatewayEvent): + """ + Sent when a guilds integrations are updated. + """ pass class GuildMembersChunk(GatewayEvent): - guild_id = snowflake - members = listof(GuildMember) + """ + Sent in response to a members chunk request. + """ + guild_id = Field(snowflake) + members = Field(listof(GuildMember)) @wraps_model(GuildMember, alias='member') class GuildMemberAdd(GatewayEvent): + """ + Sent when a user joins a guild. + """ pass class GuildMemberRemove(GatewayEvent): - guild_id = snowflake - user = User + """ + Sent when a user leaves a guild (via leaving, kicking, or banning). + """ + guild_id = Field(snowflake) + user = Field(User) @wraps_model(GuildMember, alias='member') class GuildMemberUpdate(GatewayEvent): + """ + Sent when a guilds member is updated. + """ pass class GuildRoleCreate(GatewayEvent): - guild_id = snowflake - role = Role + """ + Sent when a role is created. + """ + guild_id = Field(snowflake) + role = Field(Role) class GuildRoleUpdate(GuildRoleCreate): + """ + Sent when a role is updated. + """ pass class GuildRoleDelete(GuildRoleCreate): + """ + Sent when a role is deleted. + """ pass @wraps_model(Message) class MessageCreate(GatewayEvent): - @property - def channel(self): - return self.message.channel + """ + Sent when a message is created. + """ @wraps_model(Message) class MessageUpdate(MessageCreate): + """ + Sent when a message is updated/edited. + """ pass class MessageDelete(GatewayEvent): - id = snowflake - channel_id = snowflake + """ + Sent when a message is deleted. + """ + id = Field(snowflake) + channel_id = Field(snowflake) class MessageDeleteBulk(GatewayEvent): - channel_id = snowflake - ids = listof(snowflake) + """ + Sent when multiple messages are deleted from a channel. + """ + channel_id = Field(snowflake) + ids = Field(listof(snowflake)) class PresenceUpdate(GatewayEvent): + """ + Sent when a users presence is updated. + """ class Game(Model): # TODO enum - type = int - name = text - url = text + type = Field(int) + name = Field(text) + url = Field(text) - user = User - guild_id = snowflake - roles = listof(snowflake) - game = Game - status = text + user = Field(User) + guild_id = Field(snowflake) + roles = Field(listof(snowflake)) + game = Field(Game) + status = Field(text) class TypingStart(GatewayEvent): - channel_id = snowflake - user_id = snowflake - timestamp = snowflake + """ + Sent when a user begins typing in a channel. + """ + channel_id = Field(snowflake) + user_id = Field(snowflake) + timestamp = Field(snowflake) @wraps_model(VoiceState, alias='state') class VoiceStateUpdate(GatewayEvent): + """ + Sent when a users voice state changes. + """ pass class VoiceServerUpdate(GatewayEvent): - token = str - endpoint = str - guild_id = snowflake + """ + Sent when a voice server is updated. + """ + token = Field(str) + endpoint = Field(str) + guild_id = Field(snowflake) diff --git a/disco/types/base.py b/disco/types/base.py index fc41614..53f7fe0 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -10,6 +10,69 @@ DATETIME_FORMATS = [ ] +class FieldType(object): + def __init__(self, typ): + if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model): + self.typ = typ + else: + self.typ = lambda raw, _: typ(raw) + + def try_convert(self, raw, client): + pass + + def __call__(self, raw, client): + return self.try_convert(raw, client) + + +class Field(FieldType): + def __init__(self, typ, alias=None): + super(Field, self).__init__(typ) + + # Set names + self.src_name = alias + self.dst_name = None + + self.default = None + + if isinstance(self.typ, FieldType): + self.default = self.typ.default + + def set_name(self, name): + if not self.dst_name: + self.dst_name = name + + if not self.src_name: + self.src_name = name + + def has_default(self): + return self.default is not None + + def try_convert(self, raw, client): + return self.typ(raw, client) + + +class _Dict(FieldType): + default = dict + + def __init__(self, typ, key=None): + super(_Dict, self).__init__(typ) + self.key = key + + def try_convert(self, raw, client): + if self.key: + converted = [self.typ(i, client) for i in raw] + return {getattr(i, self.key): i for i in converted} + else: + return {k: self.typ(v, client) for k, v in six.iteritems(raw)} + + +class _List(FieldType): + default = list + + def try_convert(self, raw, client): + return [self.typ(i, client) for i in raw] + + def _make(typ, data, client): if inspect.isclass(typ) and issubclass(typ, Model): return typ(data, client) @@ -26,33 +89,12 @@ def enum(typ): return _f -def listof(typ): - def _f(data, client=None): - if not data: - return [] - return [_make(typ, obj, client) for obj in data] - _f._takes_client = None - return _f +def listof(*args, **kwargs): + return _List(*args, **kwargs) -def dictof(typ, key=None): - def _f(data, client=None): - if not data: - return {} - - if key: - return { - getattr(v, key): v for v in ( - _make(typ, i, client) for i in data - )} - else: - return {k: _make(typ, v, client) for k, v in six.iteritems(data)} - _f._takes_client = None - return _f - - -def alias(typ, name): - return ('alias', name, typ) +def dictof(*args, **kwargs): + return _Dict(*args, **kwargs) def datetime(data): @@ -76,23 +118,21 @@ def binary(obj): return six.text_type(obj) if obj else six.text_type() +def field(typ, alias=None): + pass + + class ModelMeta(type): def __new__(cls, name, parents, dct): fields = {} - for k, v in six.iteritems(dct): - if isinstance(v, tuple): - if v[0] == 'alias': - fields[v[1]] = (k, v[2]) - continue - if inspect.isclass(v): - fields[k] = v - elif callable(v): - args, _, _, _ = inspect.getargspec(v) - if 'self' in args: - continue + for k, v in six.iteritems(dct): + if not isinstance(v, Field): + continue - fields[k] = v + v.set_name(k) + fields[k] = v + dct[k] = None dct['_fields'] = fields return super(ModelMeta, cls).__new__(cls, name, parents, dct) @@ -102,40 +142,17 @@ class Model(six.with_metaclass(ModelMeta)): def __init__(self, obj, client=None): self.client = client - for name, typ in self.__class__._fields.items(): - dest_name = name - - if isinstance(typ, tuple): - dest_name, typ = typ - - if name not in obj or not obj[name]: - if inspect.isclass(typ) and issubclass(typ, Model): - res = None - elif isinstance(typ, type): - res = typ() - else: - res = typ(None) - setattr(self, dest_name, res) + for name, field in self._fields.items(): + if name not in obj or not obj[field.src_name]: + if field.has_default(): + setattr(self, field.dst_name, field.default()) continue - try: - if client: - if inspect.isfunction(typ) and hasattr(typ, '_takes_client'): - v = typ(obj[name], client) - elif inspect.isclass(typ) and issubclass(typ, Model): - v = typ(obj[name], client) - else: - v = typ(obj[name]) - else: - v = typ(obj[name]) - except Exception: - print('Failed during parsing of field {} => {}'.format(name, typ)) - raise - - setattr(self, dest_name, v) + value = field.try_convert(obj[field.src_name], client) + setattr(self, field.dst_name, value) def update(self, other): - for name in six.iterkeys(self.__class__._fields): + for name in six.iterkeys(self._fields): value = getattr(other, name) if value: setattr(self, name, value) diff --git a/disco/types/channel.py b/disco/types/channel.py index e1beea4..0e5be65 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,6 +1,6 @@ from holster.enum import Enum -from disco.types.base import Model, snowflake, enum, listof, dictof, alias, text +from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text from disco.types.permissions import PermissionValue from disco.util.functional import cached_property @@ -39,10 +39,10 @@ class PermissionOverwrite(Model): All denied permissions """ - id = snowflake - type = enum(PermissionOverwriteType) - allow = PermissionValue - deny = PermissionValue + id = Field(snowflake) + type = Field(enum(PermissionOverwriteType)) + allow = Field(PermissionValue) + deny = Field(PermissionValue) class Channel(Model, Permissible): @@ -70,16 +70,16 @@ class Channel(Model, Permissible): overwrites : dict(snowflake, :class:`disco.types.channel.PermissionOverwrite`) Channel permissions overwrites. """ - id = snowflake - guild_id = snowflake - name = text - topic = text - _last_message_id = alias(snowflake, 'last_message_id') - position = int - bitrate = int - recipients = listof(User) - type = enum(ChannelType) - overwrites = alias(dictof(PermissionOverwrite, key='id'), 'permission_overwrites') + id = Field(snowflake) + guild_id = Field(snowflake) + name = Field(text) + topic = Field(text) + _last_message_id = Field(snowflake, alias='last_message_id') + position = Field(int) + bitrate = Field(int) + recipients = Field(listof(User)) + type = Field(enum(ChannelType)) + overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') def get_permissions(self, user): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index 02d3c0d..10d93eb 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -3,7 +3,7 @@ from holster.enum import Enum from disco.api.http import APIException from disco.util import to_snowflake from disco.util.functional import cached_property -from disco.types.base import Model, snowflake, listof, dictof, datetime, text, binary, enum +from disco.types.base import Model, Field, snowflake, listof, dictof, datetime, text, binary, enum from disco.types.user import User from disco.types.voice import VoiceState from disco.types.permissions import PermissionValue, Permissions, Permissible @@ -36,11 +36,11 @@ class Emoji(Model): roles : list(snowflake) Roles this emoji is attached to. """ - id = snowflake - name = text - require_colons = bool - managed = bool - roles = listof(snowflake) + id = Field(snowflake) + name = Field(text) + require_colons = Field(bool) + managed = Field(bool) + roles = Field(listof(snowflake)) class Role(Model): @@ -64,13 +64,13 @@ class Role(Model): position : int The position of this role in the hierarchy. """ - id = snowflake - name = text - hoist = bool - managed = bool - color = int - permissions = PermissionValue - position = int + id = Field(snowflake) + name = Field(text) + hoist = Field(bool) + managed = Field(bool) + color = Field(int) + permissions = Field(PermissionValue) + position = Field(int) class GuildMember(Model): @@ -94,13 +94,13 @@ class GuildMember(Model): roles : list(snowflake) Roles this member is part of. """ - user = User - guild_id = snowflake - nick = text - mute = bool - deaf = bool - joined_at = datetime - roles = listof(snowflake) + user = Field(User) + guild_id = Field(snowflake) + nick = Field(text) + mute = Field(bool) + deaf = Field(bool) + joined_at = Field(datetime) + roles = Field(listof(snowflake)) def get_voice_state(self): """ @@ -196,24 +196,24 @@ class Guild(Model, Permissible): All of the guilds voice states. """ - id = snowflake - owner_id = snowflake - afk_channel_id = snowflake - embed_channel_id = snowflake - name = text - icon = binary - splash = binary - region = str - afk_timeout = int - embed_enabled = bool - verification_level = enum(VerificationLevel) - mfa_level = int - features = listof(str) - members = dictof(GuildMember, key='id') - channels = dictof(Channel, key='id') - roles = dictof(Role, key='id') - emojis = dictof(Emoji, key='id') - voice_states = dictof(VoiceState, key='session_id') + id = Field(snowflake) + owner_id = Field(snowflake) + afk_channel_id = Field(snowflake) + embed_channel_id = Field(snowflake) + name = Field(text) + icon = Field(binary) + splash = Field(binary) + region = Field(str) + afk_timeout = Field(int) + embed_enabled = Field(bool) + verification_level = Field(enum(VerificationLevel)) + mfa_level = Field(int) + features = Field(listof(str)) + members = Field(dictof(GuildMember, key='id')) + channels = Field(dictof(Channel, key='id')) + roles = Field(dictof(Role, key='id')) + emojis = Field(dictof(Emoji, key='id')) + voice_states = Field(dictof(VoiceState, key='session_id')) def get_permissions(self, user): """ diff --git a/disco/types/invite.py b/disco/types/invite.py index f8fbc43..2bfb355 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,4 +1,4 @@ -from disco.types.base import Model, datetime +from disco.types.base import Model, Field, datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel @@ -29,12 +29,12 @@ class Invite(Model): created_at : datetime When this invite was created. """ - code = str - inviter = User - guild = Guild - channel = Channel - max_age = int - max_uses = int - uses = int - temporary = bool - created_at = datetime + code = Field(str) + inviter = Field(User) + guild = Field(Guild) + channel = Field(Channel) + max_age = Field(int) + max_uses = Field(int) + uses = Field(int) + temporary = Field(bool) + created_at = Field(datetime) diff --git a/disco/types/message.py b/disco/types/message.py index de6d96e..ba88328 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -2,7 +2,7 @@ import re from holster.enum import Enum -from disco.types.base import Model, snowflake, text, datetime, dictof, listof, enum +from disco.types.base import Model, Field, snowflake, text, datetime, dictof, listof, enum from disco.util import to_snowflake from disco.util.functional import cached_property from disco.types.user import User @@ -34,10 +34,10 @@ class MessageEmbed(Model): url : str URL of the embed. """ - title = text - type = str - description = text - url = str + title = Field(text) + type = Field(str) + description = Field(text) + url = Field(str) class MessageAttachment(Model): @@ -61,13 +61,13 @@ class MessageAttachment(Model): width : int Width of the attachment. """ - id = str - filename = text - url = str - proxy_url = str - size = int - height = int - width = int + id = Field(str) + filename = Field(text) + url = Field(str) + proxy_url = Field(str) + size = Field(int) + height = Field(int) + width = Field(int) class Message(Model): @@ -107,21 +107,21 @@ class Message(Model): attachments : list(:class:`MessageAttachment`) All attachments for this message. """ - id = snowflake - channel_id = snowflake - type = enum(MessageType) - author = User - content = text - nonce = snowflake - timestamp = datetime - edited_timestamp = datetime - tts = bool - mention_everyone = bool - pinned = bool - mentions = dictof(User, key='id') - mention_roles = listof(snowflake) - embeds = listof(MessageEmbed) - attachments = dictof(MessageAttachment, key='id') + id = Field(snowflake) + channel_id = Field(snowflake) + type = Field(enum(MessageType)) + author = Field(User) + content = Field(text) + nonce = Field(snowflake) + timestamp = Field(datetime) + edited_timestamp = Field(datetime) + tts = Field(bool) + mention_everyone = Field(bool) + pinned = Field(bool) + mentions = Field(dictof(User, key='id')) + mention_roles = Field(listof(snowflake)) + embeds = Field(listof(MessageEmbed)) + attachments = Field(dictof(MessageAttachment, key='id')) def __str__(self): return ''.format(self.id, self.channel_id) diff --git a/disco/types/user.py b/disco/types/user.py index 73f6ff7..160f983 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -1,13 +1,13 @@ -from disco.types.base import Model, snowflake, text, binary +from disco.types.base import Model, Field, snowflake, text, binary class User(Model): - id = snowflake - username = text - discriminator = str - avatar = binary - verified = bool - email = str + id = Field(snowflake) + username = Field(text) + discriminator = Field(str) + avatar = Field(binary) + verified = Field(bool) + email = Field(str) def to_string(self): return '{}#{}'.format(self.username, self.discriminator) diff --git a/disco/types/voice.py b/disco/types/voice.py index a73e180..aadfa58 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -1,16 +1,16 @@ -from disco.types.base import Model, snowflake +from disco.types.base import Model, Field, snowflake class VoiceState(Model): - session_id = str - guild_id = snowflake - channel_id = snowflake - user_id = snowflake - deaf = bool - mute = bool - self_deaf = bool - self_mute = bool - suppress = bool + session_id = Field(str) + guild_id = Field(snowflake) + channel_id = Field(snowflake) + user_id = Field(snowflake) + deaf = Field(bool) + mute = Field(bool) + self_deaf = Field(bool) + self_mute = Field(bool) + suppress = Field(bool) @property def guild(self): diff --git a/docs/api.rst b/docs/api.rst index 808483d..c090248 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -118,7 +118,7 @@ GatewayClient Gateway Events ~~~~~~~~~~~~~~ -.. automodule:: disco.gateway.client.Events +.. automodule:: disco.gateway.events :members: diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 3bc1746..8796745 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -10,10 +10,8 @@ from disco.types.permissions import Permissions class BasicPlugin(Plugin): @Plugin.listen('MessageCreate') - def on_message_create(self, event): - self.log.info('Message created: <{}>: {}'.format( - event.message.author.username, - event.message.content)) + def on_message_create(self, msg): + self.log.info('Message created: {}: {}'.format(msg.author, msg.content)) @Plugin.command('status', '[component]') def on_status_command(self, event, component=None):