Browse Source

Refactor modeling to avoid magic as much as possible

My previous stab at implementing the simple-modeling-orm-thing-tm failed
in the aspect that there was a lot of duplicated code doing runtime
inspection of stuff. This was due mostly to having no extra place to
store information on types, making it hard to introspect how the type
expected to be built, whether it had a default, etc.

This commit refactors the modeling code to actually have a Field type,
which wraps some information up in a simple class and allows extremely
easy conversion without having to do (more) expensive runtime
inspection. This also gives us the benefits of a much more
readable/cleaner code, expandable field options, and not having to fuck
with sphinx to get docs working correctly (it was duping attributes
because they where aliases...)
pull/5/head
Andrei 9 years ago
parent
commit
ffe5a6f6c8
  1. 169
      disco/gateway/events.py
  2. 151
      disco/types/base.py
  3. 30
      disco/types/channel.py
  4. 76
      disco/types/guild.py
  5. 20
      disco/types/invite.py
  6. 54
      disco/types/message.py
  7. 14
      disco/types/user.py
  8. 20
      disco/types/voice.py
  9. 2
      docs/api.rst
  10. 6
      examples/basic_plugin.py

169
disco/gateway/events.py

@ -2,7 +2,7 @@ import inflection
import six import six
from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState
from disco.types.base import Model, snowflake, alias, listof, text from disco.types.base import Model, Field, snowflake, listof, text
# TODO: clean this... use BaseType, etc # TODO: clean this... use BaseType, etc
@ -29,164 +29,251 @@ class GatewayEvent(Model):
return cls(obj, client) return cls(obj, client)
def __getattr__(self, name):
if hasattr(self, '_wraps_model'):
modname, _ = self._wraps_model
if hasattr(self, modname) and hasattr(getattr(self, modname), name):
return getattr(getattr(self, modname), name)
return object.__getattr__(self, name)
def wraps_model(model, alias=None): def wraps_model(model, alias=None):
alias = alias or model.__name__.lower() alias = alias or model.__name__.lower()
def deco(cls): def deco(cls):
cls._fields[alias] = model cls._fields[alias] = Field(model)
cls._fields[alias].set_name(alias)
cls._wraps_model = (alias, model) cls._wraps_model = (alias, model)
return cls return cls
return deco return deco
class Ready(GatewayEvent): class Ready(GatewayEvent):
version = alias(int, 'v') """
session_id = str Sent after the initial gateway handshake is complete. Contains data required
user = User for bootstrapping the clients states.
guilds = listof(Guild) """
version = Field(int, alias='v')
session_id = Field(str)
user = Field(User)
guilds = Field(listof(Guild))
class Resumed(GatewayEvent): class Resumed(GatewayEvent):
"""
Sent after a resume completes.
"""
pass pass
@wraps_model(Guild) @wraps_model(Guild)
class GuildCreate(GatewayEvent): class GuildCreate(GatewayEvent):
unavailable = bool """
Sent when a guild is created, or becomes available.
"""
unavailable = Field(bool)
@wraps_model(Guild) @wraps_model(Guild)
class GuildUpdate(GatewayEvent): class GuildUpdate(GatewayEvent):
"""
Sent when a guild is updated.
"""
pass pass
class GuildDelete(GatewayEvent): class GuildDelete(GatewayEvent):
id = snowflake """
unavailable = bool Sent when a guild is deleted, or becomes unavailable.
"""
id = Field(snowflake)
unavailable = Field(bool)
@wraps_model(Channel) @wraps_model(Channel)
class ChannelCreate(GatewayEvent): class ChannelCreate(GatewayEvent):
@property """
def guild(self): Sent when a channel is created.
return self.channel.guild """
@wraps_model(Channel) @wraps_model(Channel)
class ChannelUpdate(ChannelCreate): class ChannelUpdate(ChannelCreate):
"""
Sent when a channel is updated.
"""
pass pass
@wraps_model(Channel) @wraps_model(Channel)
class ChannelDelete(ChannelCreate): class ChannelDelete(ChannelCreate):
"""
Sent when a channel is deleted.
"""
pass pass
class ChannelPinsUpdate(GatewayEvent): class ChannelPinsUpdate(GatewayEvent):
channel_id = snowflake """
last_pin_timestamp = int Sent when a channels pins are updated.
"""
channel_id = Field(snowflake)
last_pin_timestamp = Field(int)
@wraps_model(User) @wraps_model(User)
class GuildBanAdd(GatewayEvent): class GuildBanAdd(GatewayEvent):
"""
Sent when a user is banned from a guild.
"""
pass pass
@wraps_model(User) @wraps_model(User)
class GuildBanRemove(GuildBanAdd): class GuildBanRemove(GuildBanAdd):
"""
Sent when a user is unbanned from a guild.
"""
pass pass
class GuildEmojisUpdate(GatewayEvent): class GuildEmojisUpdate(GatewayEvent):
"""
Sent when a guilds emojis are updated.
"""
pass pass
class GuildIntegrationsUpdate(GatewayEvent): class GuildIntegrationsUpdate(GatewayEvent):
"""
Sent when a guilds integrations are updated.
"""
pass pass
class GuildMembersChunk(GatewayEvent): class GuildMembersChunk(GatewayEvent):
guild_id = snowflake """
members = listof(GuildMember) Sent in response to a members chunk request.
"""
guild_id = Field(snowflake)
members = Field(listof(GuildMember))
@wraps_model(GuildMember, alias='member') @wraps_model(GuildMember, alias='member')
class GuildMemberAdd(GatewayEvent): class GuildMemberAdd(GatewayEvent):
"""
Sent when a user joins a guild.
"""
pass pass
class GuildMemberRemove(GatewayEvent): class GuildMemberRemove(GatewayEvent):
guild_id = snowflake """
user = User Sent when a user leaves a guild (via leaving, kicking, or banning).
"""
guild_id = Field(snowflake)
user = Field(User)
@wraps_model(GuildMember, alias='member') @wraps_model(GuildMember, alias='member')
class GuildMemberUpdate(GatewayEvent): class GuildMemberUpdate(GatewayEvent):
"""
Sent when a guilds member is updated.
"""
pass pass
class GuildRoleCreate(GatewayEvent): class GuildRoleCreate(GatewayEvent):
guild_id = snowflake """
role = Role Sent when a role is created.
"""
guild_id = Field(snowflake)
role = Field(Role)
class GuildRoleUpdate(GuildRoleCreate): class GuildRoleUpdate(GuildRoleCreate):
"""
Sent when a role is updated.
"""
pass pass
class GuildRoleDelete(GuildRoleCreate): class GuildRoleDelete(GuildRoleCreate):
"""
Sent when a role is deleted.
"""
pass pass
@wraps_model(Message) @wraps_model(Message)
class MessageCreate(GatewayEvent): class MessageCreate(GatewayEvent):
@property """
def channel(self): Sent when a message is created.
return self.message.channel """
@wraps_model(Message) @wraps_model(Message)
class MessageUpdate(MessageCreate): class MessageUpdate(MessageCreate):
"""
Sent when a message is updated/edited.
"""
pass pass
class MessageDelete(GatewayEvent): class MessageDelete(GatewayEvent):
id = snowflake """
channel_id = snowflake Sent when a message is deleted.
"""
id = Field(snowflake)
channel_id = Field(snowflake)
class MessageDeleteBulk(GatewayEvent): class MessageDeleteBulk(GatewayEvent):
channel_id = snowflake """
ids = listof(snowflake) Sent when multiple messages are deleted from a channel.
"""
channel_id = Field(snowflake)
ids = Field(listof(snowflake))
class PresenceUpdate(GatewayEvent): class PresenceUpdate(GatewayEvent):
"""
Sent when a users presence is updated.
"""
class Game(Model): class Game(Model):
# TODO enum # TODO enum
type = int type = Field(int)
name = text name = Field(text)
url = text url = Field(text)
user = User user = Field(User)
guild_id = snowflake guild_id = Field(snowflake)
roles = listof(snowflake) roles = Field(listof(snowflake))
game = Game game = Field(Game)
status = text status = Field(text)
class TypingStart(GatewayEvent): class TypingStart(GatewayEvent):
channel_id = snowflake """
user_id = snowflake Sent when a user begins typing in a channel.
timestamp = snowflake """
channel_id = Field(snowflake)
user_id = Field(snowflake)
timestamp = Field(snowflake)
@wraps_model(VoiceState, alias='state') @wraps_model(VoiceState, alias='state')
class VoiceStateUpdate(GatewayEvent): class VoiceStateUpdate(GatewayEvent):
"""
Sent when a users voice state changes.
"""
pass pass
class VoiceServerUpdate(GatewayEvent): class VoiceServerUpdate(GatewayEvent):
token = str """
endpoint = str Sent when a voice server is updated.
guild_id = snowflake """
token = Field(str)
endpoint = Field(str)
guild_id = Field(snowflake)

151
disco/types/base.py

@ -10,6 +10,69 @@ DATETIME_FORMATS = [
] ]
class FieldType(object):
def __init__(self, typ):
if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model):
self.typ = typ
else:
self.typ = lambda raw, _: typ(raw)
def try_convert(self, raw, client):
pass
def __call__(self, raw, client):
return self.try_convert(raw, client)
class Field(FieldType):
def __init__(self, typ, alias=None):
super(Field, self).__init__(typ)
# Set names
self.src_name = alias
self.dst_name = None
self.default = None
if isinstance(self.typ, FieldType):
self.default = self.typ.default
def set_name(self, name):
if not self.dst_name:
self.dst_name = name
if not self.src_name:
self.src_name = name
def has_default(self):
return self.default is not None
def try_convert(self, raw, client):
return self.typ(raw, client)
class _Dict(FieldType):
default = dict
def __init__(self, typ, key=None):
super(_Dict, self).__init__(typ)
self.key = key
def try_convert(self, raw, client):
if self.key:
converted = [self.typ(i, client) for i in raw]
return {getattr(i, self.key): i for i in converted}
else:
return {k: self.typ(v, client) for k, v in six.iteritems(raw)}
class _List(FieldType):
default = list
def try_convert(self, raw, client):
return [self.typ(i, client) for i in raw]
def _make(typ, data, client): def _make(typ, data, client):
if inspect.isclass(typ) and issubclass(typ, Model): if inspect.isclass(typ) and issubclass(typ, Model):
return typ(data, client) return typ(data, client)
@ -26,33 +89,12 @@ def enum(typ):
return _f return _f
def listof(typ): def listof(*args, **kwargs):
def _f(data, client=None): return _List(*args, **kwargs)
if not data:
return []
return [_make(typ, obj, client) for obj in data]
_f._takes_client = None
return _f
def dictof(typ, key=None): def dictof(*args, **kwargs):
def _f(data, client=None): return _Dict(*args, **kwargs)
if not data:
return {}
if key:
return {
getattr(v, key): v for v in (
_make(typ, i, client) for i in data
)}
else:
return {k: _make(typ, v, client) for k, v in six.iteritems(data)}
_f._takes_client = None
return _f
def alias(typ, name):
return ('alias', name, typ)
def datetime(data): def datetime(data):
@ -76,23 +118,21 @@ def binary(obj):
return six.text_type(obj) if obj else six.text_type() return six.text_type(obj) if obj else six.text_type()
def field(typ, alias=None):
pass
class ModelMeta(type): class ModelMeta(type):
def __new__(cls, name, parents, dct): def __new__(cls, name, parents, dct):
fields = {} fields = {}
for k, v in six.iteritems(dct):
if isinstance(v, tuple):
if v[0] == 'alias':
fields[v[1]] = (k, v[2])
continue
if inspect.isclass(v): for k, v in six.iteritems(dct):
fields[k] = v if not isinstance(v, Field):
elif callable(v): continue
args, _, _, _ = inspect.getargspec(v)
if 'self' in args:
continue
fields[k] = v v.set_name(k)
fields[k] = v
dct[k] = None
dct['_fields'] = fields dct['_fields'] = fields
return super(ModelMeta, cls).__new__(cls, name, parents, dct) return super(ModelMeta, cls).__new__(cls, name, parents, dct)
@ -102,40 +142,17 @@ class Model(six.with_metaclass(ModelMeta)):
def __init__(self, obj, client=None): def __init__(self, obj, client=None):
self.client = client self.client = client
for name, typ in self.__class__._fields.items(): for name, field in self._fields.items():
dest_name = name if name not in obj or not obj[field.src_name]:
if field.has_default():
if isinstance(typ, tuple): setattr(self, field.dst_name, field.default())
dest_name, typ = typ
if name not in obj or not obj[name]:
if inspect.isclass(typ) and issubclass(typ, Model):
res = None
elif isinstance(typ, type):
res = typ()
else:
res = typ(None)
setattr(self, dest_name, res)
continue continue
try: value = field.try_convert(obj[field.src_name], client)
if client: setattr(self, field.dst_name, value)
if inspect.isfunction(typ) and hasattr(typ, '_takes_client'):
v = typ(obj[name], client)
elif inspect.isclass(typ) and issubclass(typ, Model):
v = typ(obj[name], client)
else:
v = typ(obj[name])
else:
v = typ(obj[name])
except Exception:
print('Failed during parsing of field {} => {}'.format(name, typ))
raise
setattr(self, dest_name, v)
def update(self, other): def update(self, other):
for name in six.iterkeys(self.__class__._fields): for name in six.iterkeys(self._fields):
value = getattr(other, name) value = getattr(other, name)
if value: if value:
setattr(self, name, value) setattr(self, name, value)

30
disco/types/channel.py

@ -1,6 +1,6 @@
from holster.enum import Enum from holster.enum import Enum
from disco.types.base import Model, snowflake, enum, listof, dictof, alias, text from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text
from disco.types.permissions import PermissionValue from disco.types.permissions import PermissionValue
from disco.util.functional import cached_property from disco.util.functional import cached_property
@ -39,10 +39,10 @@ class PermissionOverwrite(Model):
All denied permissions All denied permissions
""" """
id = snowflake id = Field(snowflake)
type = enum(PermissionOverwriteType) type = Field(enum(PermissionOverwriteType))
allow = PermissionValue allow = Field(PermissionValue)
deny = PermissionValue deny = Field(PermissionValue)
class Channel(Model, Permissible): class Channel(Model, Permissible):
@ -70,16 +70,16 @@ class Channel(Model, Permissible):
overwrites : dict(snowflake, :class:`disco.types.channel.PermissionOverwrite`) overwrites : dict(snowflake, :class:`disco.types.channel.PermissionOverwrite`)
Channel permissions overwrites. Channel permissions overwrites.
""" """
id = snowflake id = Field(snowflake)
guild_id = snowflake guild_id = Field(snowflake)
name = text name = Field(text)
topic = text topic = Field(text)
_last_message_id = alias(snowflake, 'last_message_id') _last_message_id = Field(snowflake, alias='last_message_id')
position = int position = Field(int)
bitrate = int bitrate = Field(int)
recipients = listof(User) recipients = Field(listof(User))
type = enum(ChannelType) type = Field(enum(ChannelType))
overwrites = alias(dictof(PermissionOverwrite, key='id'), 'permission_overwrites') overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites')
def get_permissions(self, user): def get_permissions(self, user):
""" """

76
disco/types/guild.py

@ -3,7 +3,7 @@ from holster.enum import Enum
from disco.api.http import APIException from disco.api.http import APIException
from disco.util import to_snowflake from disco.util import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property
from disco.types.base import Model, snowflake, listof, dictof, datetime, text, binary, enum from disco.types.base import Model, Field, snowflake, listof, dictof, datetime, text, binary, enum
from disco.types.user import User from disco.types.user import User
from disco.types.voice import VoiceState from disco.types.voice import VoiceState
from disco.types.permissions import PermissionValue, Permissions, Permissible from disco.types.permissions import PermissionValue, Permissions, Permissible
@ -36,11 +36,11 @@ class Emoji(Model):
roles : list(snowflake) roles : list(snowflake)
Roles this emoji is attached to. Roles this emoji is attached to.
""" """
id = snowflake id = Field(snowflake)
name = text name = Field(text)
require_colons = bool require_colons = Field(bool)
managed = bool managed = Field(bool)
roles = listof(snowflake) roles = Field(listof(snowflake))
class Role(Model): class Role(Model):
@ -64,13 +64,13 @@ class Role(Model):
position : int position : int
The position of this role in the hierarchy. The position of this role in the hierarchy.
""" """
id = snowflake id = Field(snowflake)
name = text name = Field(text)
hoist = bool hoist = Field(bool)
managed = bool managed = Field(bool)
color = int color = Field(int)
permissions = PermissionValue permissions = Field(PermissionValue)
position = int position = Field(int)
class GuildMember(Model): class GuildMember(Model):
@ -94,13 +94,13 @@ class GuildMember(Model):
roles : list(snowflake) roles : list(snowflake)
Roles this member is part of. Roles this member is part of.
""" """
user = User user = Field(User)
guild_id = snowflake guild_id = Field(snowflake)
nick = text nick = Field(text)
mute = bool mute = Field(bool)
deaf = bool deaf = Field(bool)
joined_at = datetime joined_at = Field(datetime)
roles = listof(snowflake) roles = Field(listof(snowflake))
def get_voice_state(self): def get_voice_state(self):
""" """
@ -196,24 +196,24 @@ class Guild(Model, Permissible):
All of the guilds voice states. All of the guilds voice states.
""" """
id = snowflake id = Field(snowflake)
owner_id = snowflake owner_id = Field(snowflake)
afk_channel_id = snowflake afk_channel_id = Field(snowflake)
embed_channel_id = snowflake embed_channel_id = Field(snowflake)
name = text name = Field(text)
icon = binary icon = Field(binary)
splash = binary splash = Field(binary)
region = str region = Field(str)
afk_timeout = int afk_timeout = Field(int)
embed_enabled = bool embed_enabled = Field(bool)
verification_level = enum(VerificationLevel) verification_level = Field(enum(VerificationLevel))
mfa_level = int mfa_level = Field(int)
features = listof(str) features = Field(listof(str))
members = dictof(GuildMember, key='id') members = Field(dictof(GuildMember, key='id'))
channels = dictof(Channel, key='id') channels = Field(dictof(Channel, key='id'))
roles = dictof(Role, key='id') roles = Field(dictof(Role, key='id'))
emojis = dictof(Emoji, key='id') emojis = Field(dictof(Emoji, key='id'))
voice_states = dictof(VoiceState, key='session_id') voice_states = Field(dictof(VoiceState, key='session_id'))
def get_permissions(self, user): def get_permissions(self, user):
""" """

20
disco/types/invite.py

@ -1,4 +1,4 @@
from disco.types.base import Model, datetime from disco.types.base import Model, Field, datetime
from disco.types.user import User from disco.types.user import User
from disco.types.guild import Guild from disco.types.guild import Guild
from disco.types.channel import Channel from disco.types.channel import Channel
@ -29,12 +29,12 @@ class Invite(Model):
created_at : datetime created_at : datetime
When this invite was created. When this invite was created.
""" """
code = str code = Field(str)
inviter = User inviter = Field(User)
guild = Guild guild = Field(Guild)
channel = Channel channel = Field(Channel)
max_age = int max_age = Field(int)
max_uses = int max_uses = Field(int)
uses = int uses = Field(int)
temporary = bool temporary = Field(bool)
created_at = datetime created_at = Field(datetime)

54
disco/types/message.py

@ -2,7 +2,7 @@ import re
from holster.enum import Enum from holster.enum import Enum
from disco.types.base import Model, snowflake, text, datetime, dictof, listof, enum from disco.types.base import Model, Field, snowflake, text, datetime, dictof, listof, enum
from disco.util import to_snowflake from disco.util import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property
from disco.types.user import User from disco.types.user import User
@ -34,10 +34,10 @@ class MessageEmbed(Model):
url : str url : str
URL of the embed. URL of the embed.
""" """
title = text title = Field(text)
type = str type = Field(str)
description = text description = Field(text)
url = str url = Field(str)
class MessageAttachment(Model): class MessageAttachment(Model):
@ -61,13 +61,13 @@ class MessageAttachment(Model):
width : int width : int
Width of the attachment. Width of the attachment.
""" """
id = str id = Field(str)
filename = text filename = Field(text)
url = str url = Field(str)
proxy_url = str proxy_url = Field(str)
size = int size = Field(int)
height = int height = Field(int)
width = int width = Field(int)
class Message(Model): class Message(Model):
@ -107,21 +107,21 @@ class Message(Model):
attachments : list(:class:`MessageAttachment`) attachments : list(:class:`MessageAttachment`)
All attachments for this message. All attachments for this message.
""" """
id = snowflake id = Field(snowflake)
channel_id = snowflake channel_id = Field(snowflake)
type = enum(MessageType) type = Field(enum(MessageType))
author = User author = Field(User)
content = text content = Field(text)
nonce = snowflake nonce = Field(snowflake)
timestamp = datetime timestamp = Field(datetime)
edited_timestamp = datetime edited_timestamp = Field(datetime)
tts = bool tts = Field(bool)
mention_everyone = bool mention_everyone = Field(bool)
pinned = bool pinned = Field(bool)
mentions = dictof(User, key='id') mentions = Field(dictof(User, key='id'))
mention_roles = listof(snowflake) mention_roles = Field(listof(snowflake))
embeds = listof(MessageEmbed) embeds = Field(listof(MessageEmbed))
attachments = dictof(MessageAttachment, key='id') attachments = Field(dictof(MessageAttachment, key='id'))
def __str__(self): def __str__(self):
return '<Message {} ({})>'.format(self.id, self.channel_id) return '<Message {} ({})>'.format(self.id, self.channel_id)

14
disco/types/user.py

@ -1,13 +1,13 @@
from disco.types.base import Model, snowflake, text, binary from disco.types.base import Model, Field, snowflake, text, binary
class User(Model): class User(Model):
id = snowflake id = Field(snowflake)
username = text username = Field(text)
discriminator = str discriminator = Field(str)
avatar = binary avatar = Field(binary)
verified = bool verified = Field(bool)
email = str email = Field(str)
def to_string(self): def to_string(self):
return '{}#{}'.format(self.username, self.discriminator) return '{}#{}'.format(self.username, self.discriminator)

20
disco/types/voice.py

@ -1,16 +1,16 @@
from disco.types.base import Model, snowflake from disco.types.base import Model, Field, snowflake
class VoiceState(Model): class VoiceState(Model):
session_id = str session_id = Field(str)
guild_id = snowflake guild_id = Field(snowflake)
channel_id = snowflake channel_id = Field(snowflake)
user_id = snowflake user_id = Field(snowflake)
deaf = bool deaf = Field(bool)
mute = bool mute = Field(bool)
self_deaf = bool self_deaf = Field(bool)
self_mute = bool self_mute = Field(bool)
suppress = bool suppress = Field(bool)
@property @property
def guild(self): def guild(self):

2
docs/api.rst

@ -118,7 +118,7 @@ GatewayClient
Gateway Events Gateway Events
~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~
.. automodule:: disco.gateway.client.Events .. automodule:: disco.gateway.events
:members: :members:

6
examples/basic_plugin.py

@ -10,10 +10,8 @@ from disco.types.permissions import Permissions
class BasicPlugin(Plugin): class BasicPlugin(Plugin):
@Plugin.listen('MessageCreate') @Plugin.listen('MessageCreate')
def on_message_create(self, event): def on_message_create(self, msg):
self.log.info('Message created: <{}>: {}'.format( self.log.info('Message created: {}: {}'.format(msg.author, msg.content))
event.message.author.username,
event.message.content))
@Plugin.command('status', '[component]') @Plugin.command('status', '[component]')
def on_status_command(self, event, component=None): def on_status_command(self, event, component=None):

Loading…
Cancel
Save