Browse Source

Refactor modeling to avoid magic as much as possible

My previous stab at implementing the simple-modeling-orm-thing-tm failed
in the aspect that there was a lot of duplicated code doing runtime
inspection of stuff. This was due mostly to having no extra place to
store information on types, making it hard to introspect how the type
expected to be built, whether it had a default, etc.

This commit refactors the modeling code to actually have a Field type,
which wraps some information up in a simple class and allows extremely
easy conversion without having to do (more) expensive runtime
inspection. This also gives us the benefits of a much more
readable/cleaner code, expandable field options, and not having to fuck
with sphinx to get docs working correctly (it was duping attributes
because they where aliases...)
pull/5/head
Andrei 9 years ago
parent
commit
ffe5a6f6c8
  1. 169
      disco/gateway/events.py
  2. 151
      disco/types/base.py
  3. 30
      disco/types/channel.py
  4. 76
      disco/types/guild.py
  5. 20
      disco/types/invite.py
  6. 54
      disco/types/message.py
  7. 14
      disco/types/user.py
  8. 20
      disco/types/voice.py
  9. 2
      docs/api.rst
  10. 6
      examples/basic_plugin.py

169
disco/gateway/events.py

@ -2,7 +2,7 @@ import inflection
import six
from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState
from disco.types.base import Model, snowflake, alias, listof, text
from disco.types.base import Model, Field, snowflake, listof, text
# TODO: clean this... use BaseType, etc
@ -29,164 +29,251 @@ class GatewayEvent(Model):
return cls(obj, client)
def __getattr__(self, name):
if hasattr(self, '_wraps_model'):
modname, _ = self._wraps_model
if hasattr(self, modname) and hasattr(getattr(self, modname), name):
return getattr(getattr(self, modname), name)
return object.__getattr__(self, name)
def wraps_model(model, alias=None):
alias = alias or model.__name__.lower()
def deco(cls):
cls._fields[alias] = model
cls._fields[alias] = Field(model)
cls._fields[alias].set_name(alias)
cls._wraps_model = (alias, model)
return cls
return deco
class Ready(GatewayEvent):
version = alias(int, 'v')
session_id = str
user = User
guilds = listof(Guild)
"""
Sent after the initial gateway handshake is complete. Contains data required
for bootstrapping the clients states.
"""
version = Field(int, alias='v')
session_id = Field(str)
user = Field(User)
guilds = Field(listof(Guild))
class Resumed(GatewayEvent):
"""
Sent after a resume completes.
"""
pass
@wraps_model(Guild)
class GuildCreate(GatewayEvent):
unavailable = bool
"""
Sent when a guild is created, or becomes available.
"""
unavailable = Field(bool)
@wraps_model(Guild)
class GuildUpdate(GatewayEvent):
"""
Sent when a guild is updated.
"""
pass
class GuildDelete(GatewayEvent):
id = snowflake
unavailable = bool
"""
Sent when a guild is deleted, or becomes unavailable.
"""
id = Field(snowflake)
unavailable = Field(bool)
@wraps_model(Channel)
class ChannelCreate(GatewayEvent):
@property
def guild(self):
return self.channel.guild
"""
Sent when a channel is created.
"""
@wraps_model(Channel)
class ChannelUpdate(ChannelCreate):
"""
Sent when a channel is updated.
"""
pass
@wraps_model(Channel)
class ChannelDelete(ChannelCreate):
"""
Sent when a channel is deleted.
"""
pass
class ChannelPinsUpdate(GatewayEvent):
channel_id = snowflake
last_pin_timestamp = int
"""
Sent when a channels pins are updated.
"""
channel_id = Field(snowflake)
last_pin_timestamp = Field(int)
@wraps_model(User)
class GuildBanAdd(GatewayEvent):
"""
Sent when a user is banned from a guild.
"""
pass
@wraps_model(User)
class GuildBanRemove(GuildBanAdd):
"""
Sent when a user is unbanned from a guild.
"""
pass
class GuildEmojisUpdate(GatewayEvent):
"""
Sent when a guilds emojis are updated.
"""
pass
class GuildIntegrationsUpdate(GatewayEvent):
"""
Sent when a guilds integrations are updated.
"""
pass
class GuildMembersChunk(GatewayEvent):
guild_id = snowflake
members = listof(GuildMember)
"""
Sent in response to a members chunk request.
"""
guild_id = Field(snowflake)
members = Field(listof(GuildMember))
@wraps_model(GuildMember, alias='member')
class GuildMemberAdd(GatewayEvent):
"""
Sent when a user joins a guild.
"""
pass
class GuildMemberRemove(GatewayEvent):
guild_id = snowflake
user = User
"""
Sent when a user leaves a guild (via leaving, kicking, or banning).
"""
guild_id = Field(snowflake)
user = Field(User)
@wraps_model(GuildMember, alias='member')
class GuildMemberUpdate(GatewayEvent):
"""
Sent when a guilds member is updated.
"""
pass
class GuildRoleCreate(GatewayEvent):
guild_id = snowflake
role = Role
"""
Sent when a role is created.
"""
guild_id = Field(snowflake)
role = Field(Role)
class GuildRoleUpdate(GuildRoleCreate):
"""
Sent when a role is updated.
"""
pass
class GuildRoleDelete(GuildRoleCreate):
"""
Sent when a role is deleted.
"""
pass
@wraps_model(Message)
class MessageCreate(GatewayEvent):
@property
def channel(self):
return self.message.channel
"""
Sent when a message is created.
"""
@wraps_model(Message)
class MessageUpdate(MessageCreate):
"""
Sent when a message is updated/edited.
"""
pass
class MessageDelete(GatewayEvent):
id = snowflake
channel_id = snowflake
"""
Sent when a message is deleted.
"""
id = Field(snowflake)
channel_id = Field(snowflake)
class MessageDeleteBulk(GatewayEvent):
channel_id = snowflake
ids = listof(snowflake)
"""
Sent when multiple messages are deleted from a channel.
"""
channel_id = Field(snowflake)
ids = Field(listof(snowflake))
class PresenceUpdate(GatewayEvent):
"""
Sent when a users presence is updated.
"""
class Game(Model):
# TODO enum
type = int
name = text
url = text
type = Field(int)
name = Field(text)
url = Field(text)
user = User
guild_id = snowflake
roles = listof(snowflake)
game = Game
status = text
user = Field(User)
guild_id = Field(snowflake)
roles = Field(listof(snowflake))
game = Field(Game)
status = Field(text)
class TypingStart(GatewayEvent):
channel_id = snowflake
user_id = snowflake
timestamp = snowflake
"""
Sent when a user begins typing in a channel.
"""
channel_id = Field(snowflake)
user_id = Field(snowflake)
timestamp = Field(snowflake)
@wraps_model(VoiceState, alias='state')
class VoiceStateUpdate(GatewayEvent):
"""
Sent when a users voice state changes.
"""
pass
class VoiceServerUpdate(GatewayEvent):
token = str
endpoint = str
guild_id = snowflake
"""
Sent when a voice server is updated.
"""
token = Field(str)
endpoint = Field(str)
guild_id = Field(snowflake)

151
disco/types/base.py

@ -10,6 +10,69 @@ DATETIME_FORMATS = [
]
class FieldType(object):
def __init__(self, typ):
if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model):
self.typ = typ
else:
self.typ = lambda raw, _: typ(raw)
def try_convert(self, raw, client):
pass
def __call__(self, raw, client):
return self.try_convert(raw, client)
class Field(FieldType):
def __init__(self, typ, alias=None):
super(Field, self).__init__(typ)
# Set names
self.src_name = alias
self.dst_name = None
self.default = None
if isinstance(self.typ, FieldType):
self.default = self.typ.default
def set_name(self, name):
if not self.dst_name:
self.dst_name = name
if not self.src_name:
self.src_name = name
def has_default(self):
return self.default is not None
def try_convert(self, raw, client):
return self.typ(raw, client)
class _Dict(FieldType):
default = dict
def __init__(self, typ, key=None):
super(_Dict, self).__init__(typ)
self.key = key
def try_convert(self, raw, client):
if self.key:
converted = [self.typ(i, client) for i in raw]
return {getattr(i, self.key): i for i in converted}
else:
return {k: self.typ(v, client) for k, v in six.iteritems(raw)}
class _List(FieldType):
default = list
def try_convert(self, raw, client):
return [self.typ(i, client) for i in raw]
def _make(typ, data, client):
if inspect.isclass(typ) and issubclass(typ, Model):
return typ(data, client)
@ -26,33 +89,12 @@ def enum(typ):
return _f
def listof(typ):
def _f(data, client=None):
if not data:
return []
return [_make(typ, obj, client) for obj in data]
_f._takes_client = None
return _f
def listof(*args, **kwargs):
return _List(*args, **kwargs)
def dictof(typ, key=None):
def _f(data, client=None):
if not data:
return {}
if key:
return {
getattr(v, key): v for v in (
_make(typ, i, client) for i in data
)}
else:
return {k: _make(typ, v, client) for k, v in six.iteritems(data)}
_f._takes_client = None
return _f
def alias(typ, name):
return ('alias', name, typ)
def dictof(*args, **kwargs):
return _Dict(*args, **kwargs)
def datetime(data):
@ -76,23 +118,21 @@ def binary(obj):
return six.text_type(obj) if obj else six.text_type()
def field(typ, alias=None):
pass
class ModelMeta(type):
def __new__(cls, name, parents, dct):
fields = {}
for k, v in six.iteritems(dct):
if isinstance(v, tuple):
if v[0] == 'alias':
fields[v[1]] = (k, v[2])
continue
if inspect.isclass(v):
fields[k] = v
elif callable(v):
args, _, _, _ = inspect.getargspec(v)
if 'self' in args:
continue
for k, v in six.iteritems(dct):
if not isinstance(v, Field):
continue
fields[k] = v
v.set_name(k)
fields[k] = v
dct[k] = None
dct['_fields'] = fields
return super(ModelMeta, cls).__new__(cls, name, parents, dct)
@ -102,40 +142,17 @@ class Model(six.with_metaclass(ModelMeta)):
def __init__(self, obj, client=None):
self.client = client
for name, typ in self.__class__._fields.items():
dest_name = name
if isinstance(typ, tuple):
dest_name, typ = typ
if name not in obj or not obj[name]:
if inspect.isclass(typ) and issubclass(typ, Model):
res = None
elif isinstance(typ, type):
res = typ()
else:
res = typ(None)
setattr(self, dest_name, res)
for name, field in self._fields.items():
if name not in obj or not obj[field.src_name]:
if field.has_default():
setattr(self, field.dst_name, field.default())
continue
try:
if client:
if inspect.isfunction(typ) and hasattr(typ, '_takes_client'):
v = typ(obj[name], client)
elif inspect.isclass(typ) and issubclass(typ, Model):
v = typ(obj[name], client)
else:
v = typ(obj[name])
else:
v = typ(obj[name])
except Exception:
print('Failed during parsing of field {} => {}'.format(name, typ))
raise
setattr(self, dest_name, v)
value = field.try_convert(obj[field.src_name], client)
setattr(self, field.dst_name, value)
def update(self, other):
for name in six.iterkeys(self.__class__._fields):
for name in six.iterkeys(self._fields):
value = getattr(other, name)
if value:
setattr(self, name, value)

30
disco/types/channel.py

@ -1,6 +1,6 @@
from holster.enum import Enum
from disco.types.base import Model, snowflake, enum, listof, dictof, alias, text
from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text
from disco.types.permissions import PermissionValue
from disco.util.functional import cached_property
@ -39,10 +39,10 @@ class PermissionOverwrite(Model):
All denied permissions
"""
id = snowflake
type = enum(PermissionOverwriteType)
allow = PermissionValue
deny = PermissionValue
id = Field(snowflake)
type = Field(enum(PermissionOverwriteType))
allow = Field(PermissionValue)
deny = Field(PermissionValue)
class Channel(Model, Permissible):
@ -70,16 +70,16 @@ class Channel(Model, Permissible):
overwrites : dict(snowflake, :class:`disco.types.channel.PermissionOverwrite`)
Channel permissions overwrites.
"""
id = snowflake
guild_id = snowflake
name = text
topic = text
_last_message_id = alias(snowflake, 'last_message_id')
position = int
bitrate = int
recipients = listof(User)
type = enum(ChannelType)
overwrites = alias(dictof(PermissionOverwrite, key='id'), 'permission_overwrites')
id = Field(snowflake)
guild_id = Field(snowflake)
name = Field(text)
topic = Field(text)
_last_message_id = Field(snowflake, alias='last_message_id')
position = Field(int)
bitrate = Field(int)
recipients = Field(listof(User))
type = Field(enum(ChannelType))
overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites')
def get_permissions(self, user):
"""

76
disco/types/guild.py

@ -3,7 +3,7 @@ from holster.enum import Enum
from disco.api.http import APIException
from disco.util import to_snowflake
from disco.util.functional import cached_property
from disco.types.base import Model, snowflake, listof, dictof, datetime, text, binary, enum
from disco.types.base import Model, Field, snowflake, listof, dictof, datetime, text, binary, enum
from disco.types.user import User
from disco.types.voice import VoiceState
from disco.types.permissions import PermissionValue, Permissions, Permissible
@ -36,11 +36,11 @@ class Emoji(Model):
roles : list(snowflake)
Roles this emoji is attached to.
"""
id = snowflake
name = text
require_colons = bool
managed = bool
roles = listof(snowflake)
id = Field(snowflake)
name = Field(text)
require_colons = Field(bool)
managed = Field(bool)
roles = Field(listof(snowflake))
class Role(Model):
@ -64,13 +64,13 @@ class Role(Model):
position : int
The position of this role in the hierarchy.
"""
id = snowflake
name = text
hoist = bool
managed = bool
color = int
permissions = PermissionValue
position = int
id = Field(snowflake)
name = Field(text)
hoist = Field(bool)
managed = Field(bool)
color = Field(int)
permissions = Field(PermissionValue)
position = Field(int)
class GuildMember(Model):
@ -94,13 +94,13 @@ class GuildMember(Model):
roles : list(snowflake)
Roles this member is part of.
"""
user = User
guild_id = snowflake
nick = text
mute = bool
deaf = bool
joined_at = datetime
roles = listof(snowflake)
user = Field(User)
guild_id = Field(snowflake)
nick = Field(text)
mute = Field(bool)
deaf = Field(bool)
joined_at = Field(datetime)
roles = Field(listof(snowflake))
def get_voice_state(self):
"""
@ -196,24 +196,24 @@ class Guild(Model, Permissible):
All of the guilds voice states.
"""
id = snowflake
owner_id = snowflake
afk_channel_id = snowflake
embed_channel_id = snowflake
name = text
icon = binary
splash = binary
region = str
afk_timeout = int
embed_enabled = bool
verification_level = enum(VerificationLevel)
mfa_level = int
features = listof(str)
members = dictof(GuildMember, key='id')
channels = dictof(Channel, key='id')
roles = dictof(Role, key='id')
emojis = dictof(Emoji, key='id')
voice_states = dictof(VoiceState, key='session_id')
id = Field(snowflake)
owner_id = Field(snowflake)
afk_channel_id = Field(snowflake)
embed_channel_id = Field(snowflake)
name = Field(text)
icon = Field(binary)
splash = Field(binary)
region = Field(str)
afk_timeout = Field(int)
embed_enabled = Field(bool)
verification_level = Field(enum(VerificationLevel))
mfa_level = Field(int)
features = Field(listof(str))
members = Field(dictof(GuildMember, key='id'))
channels = Field(dictof(Channel, key='id'))
roles = Field(dictof(Role, key='id'))
emojis = Field(dictof(Emoji, key='id'))
voice_states = Field(dictof(VoiceState, key='session_id'))
def get_permissions(self, user):
"""

20
disco/types/invite.py

@ -1,4 +1,4 @@
from disco.types.base import Model, datetime
from disco.types.base import Model, Field, datetime
from disco.types.user import User
from disco.types.guild import Guild
from disco.types.channel import Channel
@ -29,12 +29,12 @@ class Invite(Model):
created_at : datetime
When this invite was created.
"""
code = str
inviter = User
guild = Guild
channel = Channel
max_age = int
max_uses = int
uses = int
temporary = bool
created_at = datetime
code = Field(str)
inviter = Field(User)
guild = Field(Guild)
channel = Field(Channel)
max_age = Field(int)
max_uses = Field(int)
uses = Field(int)
temporary = Field(bool)
created_at = Field(datetime)

54
disco/types/message.py

@ -2,7 +2,7 @@ import re
from holster.enum import Enum
from disco.types.base import Model, snowflake, text, datetime, dictof, listof, enum
from disco.types.base import Model, Field, snowflake, text, datetime, dictof, listof, enum
from disco.util import to_snowflake
from disco.util.functional import cached_property
from disco.types.user import User
@ -34,10 +34,10 @@ class MessageEmbed(Model):
url : str
URL of the embed.
"""
title = text
type = str
description = text
url = str
title = Field(text)
type = Field(str)
description = Field(text)
url = Field(str)
class MessageAttachment(Model):
@ -61,13 +61,13 @@ class MessageAttachment(Model):
width : int
Width of the attachment.
"""
id = str
filename = text
url = str
proxy_url = str
size = int
height = int
width = int
id = Field(str)
filename = Field(text)
url = Field(str)
proxy_url = Field(str)
size = Field(int)
height = Field(int)
width = Field(int)
class Message(Model):
@ -107,21 +107,21 @@ class Message(Model):
attachments : list(:class:`MessageAttachment`)
All attachments for this message.
"""
id = snowflake
channel_id = snowflake
type = enum(MessageType)
author = User
content = text
nonce = snowflake
timestamp = datetime
edited_timestamp = datetime
tts = bool
mention_everyone = bool
pinned = bool
mentions = dictof(User, key='id')
mention_roles = listof(snowflake)
embeds = listof(MessageEmbed)
attachments = dictof(MessageAttachment, key='id')
id = Field(snowflake)
channel_id = Field(snowflake)
type = Field(enum(MessageType))
author = Field(User)
content = Field(text)
nonce = Field(snowflake)
timestamp = Field(datetime)
edited_timestamp = Field(datetime)
tts = Field(bool)
mention_everyone = Field(bool)
pinned = Field(bool)
mentions = Field(dictof(User, key='id'))
mention_roles = Field(listof(snowflake))
embeds = Field(listof(MessageEmbed))
attachments = Field(dictof(MessageAttachment, key='id'))
def __str__(self):
return '<Message {} ({})>'.format(self.id, self.channel_id)

14
disco/types/user.py

@ -1,13 +1,13 @@
from disco.types.base import Model, snowflake, text, binary
from disco.types.base import Model, Field, snowflake, text, binary
class User(Model):
id = snowflake
username = text
discriminator = str
avatar = binary
verified = bool
email = str
id = Field(snowflake)
username = Field(text)
discriminator = Field(str)
avatar = Field(binary)
verified = Field(bool)
email = Field(str)
def to_string(self):
return '{}#{}'.format(self.username, self.discriminator)

20
disco/types/voice.py

@ -1,16 +1,16 @@
from disco.types.base import Model, snowflake
from disco.types.base import Model, Field, snowflake
class VoiceState(Model):
session_id = str
guild_id = snowflake
channel_id = snowflake
user_id = snowflake
deaf = bool
mute = bool
self_deaf = bool
self_mute = bool
suppress = bool
session_id = Field(str)
guild_id = Field(snowflake)
channel_id = Field(snowflake)
user_id = Field(snowflake)
deaf = Field(bool)
mute = Field(bool)
self_deaf = Field(bool)
self_mute = Field(bool)
suppress = Field(bool)
@property
def guild(self):

2
docs/api.rst

@ -118,7 +118,7 @@ GatewayClient
Gateway Events
~~~~~~~~~~~~~~
.. automodule:: disco.gateway.client.Events
.. automodule:: disco.gateway.events
:members:

6
examples/basic_plugin.py

@ -10,10 +10,8 @@ from disco.types.permissions import Permissions
class BasicPlugin(Plugin):
@Plugin.listen('MessageCreate')
def on_message_create(self, event):
self.log.info('Message created: <{}>: {}'.format(
event.message.author.username,
event.message.content))
def on_message_create(self, msg):
self.log.info('Message created: {}: {}'.format(msg.author, msg.content))
@Plugin.command('status', '[component]')
def on_status_command(self, event, component=None):

Loading…
Cancel
Save