diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6d809de..fd250a1 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -2,7 +2,7 @@ import inflection import six from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceState -from disco.types.base import Model, snowflake, alias, listof, text +from disco.types.base import Model, Field, snowflake, listof, text # TODO: clean this... use BaseType, etc @@ -29,164 +29,251 @@ class GatewayEvent(Model): return cls(obj, client) + def __getattr__(self, name): + if hasattr(self, '_wraps_model'): + modname, _ = self._wraps_model + if hasattr(self, modname) and hasattr(getattr(self, modname), name): + return getattr(getattr(self, modname), name) + return object.__getattr__(self, name) + def wraps_model(model, alias=None): alias = alias or model.__name__.lower() def deco(cls): - cls._fields[alias] = model + cls._fields[alias] = Field(model) + cls._fields[alias].set_name(alias) cls._wraps_model = (alias, model) return cls return deco class Ready(GatewayEvent): - version = alias(int, 'v') - session_id = str - user = User - guilds = listof(Guild) + """ + Sent after the initial gateway handshake is complete. Contains data required + for bootstrapping the clients states. + """ + version = Field(int, alias='v') + session_id = Field(str) + user = Field(User) + guilds = Field(listof(Guild)) class Resumed(GatewayEvent): + """ + Sent after a resume completes. + """ pass @wraps_model(Guild) class GuildCreate(GatewayEvent): - unavailable = bool + """ + Sent when a guild is created, or becomes available. + """ + unavailable = Field(bool) @wraps_model(Guild) class GuildUpdate(GatewayEvent): + """ + Sent when a guild is updated. + """ pass class GuildDelete(GatewayEvent): - id = snowflake - unavailable = bool + """ + Sent when a guild is deleted, or becomes unavailable. + """ + id = Field(snowflake) + unavailable = Field(bool) @wraps_model(Channel) class ChannelCreate(GatewayEvent): - @property - def guild(self): - return self.channel.guild + """ + Sent when a channel is created. + """ @wraps_model(Channel) class ChannelUpdate(ChannelCreate): + """ + Sent when a channel is updated. + """ pass @wraps_model(Channel) class ChannelDelete(ChannelCreate): + """ + Sent when a channel is deleted. + """ pass class ChannelPinsUpdate(GatewayEvent): - channel_id = snowflake - last_pin_timestamp = int + """ + Sent when a channels pins are updated. + """ + channel_id = Field(snowflake) + last_pin_timestamp = Field(int) @wraps_model(User) class GuildBanAdd(GatewayEvent): + """ + Sent when a user is banned from a guild. + """ pass @wraps_model(User) class GuildBanRemove(GuildBanAdd): + """ + Sent when a user is unbanned from a guild. + """ pass class GuildEmojisUpdate(GatewayEvent): + """ + Sent when a guilds emojis are updated. + """ pass class GuildIntegrationsUpdate(GatewayEvent): + """ + Sent when a guilds integrations are updated. + """ pass class GuildMembersChunk(GatewayEvent): - guild_id = snowflake - members = listof(GuildMember) + """ + Sent in response to a members chunk request. + """ + guild_id = Field(snowflake) + members = Field(listof(GuildMember)) @wraps_model(GuildMember, alias='member') class GuildMemberAdd(GatewayEvent): + """ + Sent when a user joins a guild. + """ pass class GuildMemberRemove(GatewayEvent): - guild_id = snowflake - user = User + """ + Sent when a user leaves a guild (via leaving, kicking, or banning). + """ + guild_id = Field(snowflake) + user = Field(User) @wraps_model(GuildMember, alias='member') class GuildMemberUpdate(GatewayEvent): + """ + Sent when a guilds member is updated. + """ pass class GuildRoleCreate(GatewayEvent): - guild_id = snowflake - role = Role + """ + Sent when a role is created. + """ + guild_id = Field(snowflake) + role = Field(Role) class GuildRoleUpdate(GuildRoleCreate): + """ + Sent when a role is updated. + """ pass class GuildRoleDelete(GuildRoleCreate): + """ + Sent when a role is deleted. + """ pass @wraps_model(Message) class MessageCreate(GatewayEvent): - @property - def channel(self): - return self.message.channel + """ + Sent when a message is created. + """ @wraps_model(Message) class MessageUpdate(MessageCreate): + """ + Sent when a message is updated/edited. + """ pass class MessageDelete(GatewayEvent): - id = snowflake - channel_id = snowflake + """ + Sent when a message is deleted. + """ + id = Field(snowflake) + channel_id = Field(snowflake) class MessageDeleteBulk(GatewayEvent): - channel_id = snowflake - ids = listof(snowflake) + """ + Sent when multiple messages are deleted from a channel. + """ + channel_id = Field(snowflake) + ids = Field(listof(snowflake)) class PresenceUpdate(GatewayEvent): + """ + Sent when a users presence is updated. + """ class Game(Model): # TODO enum - type = int - name = text - url = text + type = Field(int) + name = Field(text) + url = Field(text) - user = User - guild_id = snowflake - roles = listof(snowflake) - game = Game - status = text + user = Field(User) + guild_id = Field(snowflake) + roles = Field(listof(snowflake)) + game = Field(Game) + status = Field(text) class TypingStart(GatewayEvent): - channel_id = snowflake - user_id = snowflake - timestamp = snowflake + """ + Sent when a user begins typing in a channel. + """ + channel_id = Field(snowflake) + user_id = Field(snowflake) + timestamp = Field(snowflake) @wraps_model(VoiceState, alias='state') class VoiceStateUpdate(GatewayEvent): + """ + Sent when a users voice state changes. + """ pass class VoiceServerUpdate(GatewayEvent): - token = str - endpoint = str - guild_id = snowflake + """ + Sent when a voice server is updated. + """ + token = Field(str) + endpoint = Field(str) + guild_id = Field(snowflake) diff --git a/disco/types/base.py b/disco/types/base.py index fc41614..53f7fe0 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -10,6 +10,69 @@ DATETIME_FORMATS = [ ] +class FieldType(object): + def __init__(self, typ): + if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model): + self.typ = typ + else: + self.typ = lambda raw, _: typ(raw) + + def try_convert(self, raw, client): + pass + + def __call__(self, raw, client): + return self.try_convert(raw, client) + + +class Field(FieldType): + def __init__(self, typ, alias=None): + super(Field, self).__init__(typ) + + # Set names + self.src_name = alias + self.dst_name = None + + self.default = None + + if isinstance(self.typ, FieldType): + self.default = self.typ.default + + def set_name(self, name): + if not self.dst_name: + self.dst_name = name + + if not self.src_name: + self.src_name = name + + def has_default(self): + return self.default is not None + + def try_convert(self, raw, client): + return self.typ(raw, client) + + +class _Dict(FieldType): + default = dict + + def __init__(self, typ, key=None): + super(_Dict, self).__init__(typ) + self.key = key + + def try_convert(self, raw, client): + if self.key: + converted = [self.typ(i, client) for i in raw] + return {getattr(i, self.key): i for i in converted} + else: + return {k: self.typ(v, client) for k, v in six.iteritems(raw)} + + +class _List(FieldType): + default = list + + def try_convert(self, raw, client): + return [self.typ(i, client) for i in raw] + + def _make(typ, data, client): if inspect.isclass(typ) and issubclass(typ, Model): return typ(data, client) @@ -26,33 +89,12 @@ def enum(typ): return _f -def listof(typ): - def _f(data, client=None): - if not data: - return [] - return [_make(typ, obj, client) for obj in data] - _f._takes_client = None - return _f +def listof(*args, **kwargs): + return _List(*args, **kwargs) -def dictof(typ, key=None): - def _f(data, client=None): - if not data: - return {} - - if key: - return { - getattr(v, key): v for v in ( - _make(typ, i, client) for i in data - )} - else: - return {k: _make(typ, v, client) for k, v in six.iteritems(data)} - _f._takes_client = None - return _f - - -def alias(typ, name): - return ('alias', name, typ) +def dictof(*args, **kwargs): + return _Dict(*args, **kwargs) def datetime(data): @@ -76,23 +118,21 @@ def binary(obj): return six.text_type(obj) if obj else six.text_type() +def field(typ, alias=None): + pass + + class ModelMeta(type): def __new__(cls, name, parents, dct): fields = {} - for k, v in six.iteritems(dct): - if isinstance(v, tuple): - if v[0] == 'alias': - fields[v[1]] = (k, v[2]) - continue - if inspect.isclass(v): - fields[k] = v - elif callable(v): - args, _, _, _ = inspect.getargspec(v) - if 'self' in args: - continue + for k, v in six.iteritems(dct): + if not isinstance(v, Field): + continue - fields[k] = v + v.set_name(k) + fields[k] = v + dct[k] = None dct['_fields'] = fields return super(ModelMeta, cls).__new__(cls, name, parents, dct) @@ -102,40 +142,17 @@ class Model(six.with_metaclass(ModelMeta)): def __init__(self, obj, client=None): self.client = client - for name, typ in self.__class__._fields.items(): - dest_name = name - - if isinstance(typ, tuple): - dest_name, typ = typ - - if name not in obj or not obj[name]: - if inspect.isclass(typ) and issubclass(typ, Model): - res = None - elif isinstance(typ, type): - res = typ() - else: - res = typ(None) - setattr(self, dest_name, res) + for name, field in self._fields.items(): + if name not in obj or not obj[field.src_name]: + if field.has_default(): + setattr(self, field.dst_name, field.default()) continue - try: - if client: - if inspect.isfunction(typ) and hasattr(typ, '_takes_client'): - v = typ(obj[name], client) - elif inspect.isclass(typ) and issubclass(typ, Model): - v = typ(obj[name], client) - else: - v = typ(obj[name]) - else: - v = typ(obj[name]) - except Exception: - print('Failed during parsing of field {} => {}'.format(name, typ)) - raise - - setattr(self, dest_name, v) + value = field.try_convert(obj[field.src_name], client) + setattr(self, field.dst_name, value) def update(self, other): - for name in six.iterkeys(self.__class__._fields): + for name in six.iterkeys(self._fields): value = getattr(other, name) if value: setattr(self, name, value) diff --git a/disco/types/channel.py b/disco/types/channel.py index e1beea4..0e5be65 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -1,6 +1,6 @@ from holster.enum import Enum -from disco.types.base import Model, snowflake, enum, listof, dictof, alias, text +from disco.types.base import Model, Field, snowflake, enum, listof, dictof, text from disco.types.permissions import PermissionValue from disco.util.functional import cached_property @@ -39,10 +39,10 @@ class PermissionOverwrite(Model): All denied permissions """ - id = snowflake - type = enum(PermissionOverwriteType) - allow = PermissionValue - deny = PermissionValue + id = Field(snowflake) + type = Field(enum(PermissionOverwriteType)) + allow = Field(PermissionValue) + deny = Field(PermissionValue) class Channel(Model, Permissible): @@ -70,16 +70,16 @@ class Channel(Model, Permissible): overwrites : dict(snowflake, :class:`disco.types.channel.PermissionOverwrite`) Channel permissions overwrites. """ - id = snowflake - guild_id = snowflake - name = text - topic = text - _last_message_id = alias(snowflake, 'last_message_id') - position = int - bitrate = int - recipients = listof(User) - type = enum(ChannelType) - overwrites = alias(dictof(PermissionOverwrite, key='id'), 'permission_overwrites') + id = Field(snowflake) + guild_id = Field(snowflake) + name = Field(text) + topic = Field(text) + _last_message_id = Field(snowflake, alias='last_message_id') + position = Field(int) + bitrate = Field(int) + recipients = Field(listof(User)) + type = Field(enum(ChannelType)) + overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') def get_permissions(self, user): """ diff --git a/disco/types/guild.py b/disco/types/guild.py index 02d3c0d..10d93eb 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -3,7 +3,7 @@ from holster.enum import Enum from disco.api.http import APIException from disco.util import to_snowflake from disco.util.functional import cached_property -from disco.types.base import Model, snowflake, listof, dictof, datetime, text, binary, enum +from disco.types.base import Model, Field, snowflake, listof, dictof, datetime, text, binary, enum from disco.types.user import User from disco.types.voice import VoiceState from disco.types.permissions import PermissionValue, Permissions, Permissible @@ -36,11 +36,11 @@ class Emoji(Model): roles : list(snowflake) Roles this emoji is attached to. """ - id = snowflake - name = text - require_colons = bool - managed = bool - roles = listof(snowflake) + id = Field(snowflake) + name = Field(text) + require_colons = Field(bool) + managed = Field(bool) + roles = Field(listof(snowflake)) class Role(Model): @@ -64,13 +64,13 @@ class Role(Model): position : int The position of this role in the hierarchy. """ - id = snowflake - name = text - hoist = bool - managed = bool - color = int - permissions = PermissionValue - position = int + id = Field(snowflake) + name = Field(text) + hoist = Field(bool) + managed = Field(bool) + color = Field(int) + permissions = Field(PermissionValue) + position = Field(int) class GuildMember(Model): @@ -94,13 +94,13 @@ class GuildMember(Model): roles : list(snowflake) Roles this member is part of. """ - user = User - guild_id = snowflake - nick = text - mute = bool - deaf = bool - joined_at = datetime - roles = listof(snowflake) + user = Field(User) + guild_id = Field(snowflake) + nick = Field(text) + mute = Field(bool) + deaf = Field(bool) + joined_at = Field(datetime) + roles = Field(listof(snowflake)) def get_voice_state(self): """ @@ -196,24 +196,24 @@ class Guild(Model, Permissible): All of the guilds voice states. """ - id = snowflake - owner_id = snowflake - afk_channel_id = snowflake - embed_channel_id = snowflake - name = text - icon = binary - splash = binary - region = str - afk_timeout = int - embed_enabled = bool - verification_level = enum(VerificationLevel) - mfa_level = int - features = listof(str) - members = dictof(GuildMember, key='id') - channels = dictof(Channel, key='id') - roles = dictof(Role, key='id') - emojis = dictof(Emoji, key='id') - voice_states = dictof(VoiceState, key='session_id') + id = Field(snowflake) + owner_id = Field(snowflake) + afk_channel_id = Field(snowflake) + embed_channel_id = Field(snowflake) + name = Field(text) + icon = Field(binary) + splash = Field(binary) + region = Field(str) + afk_timeout = Field(int) + embed_enabled = Field(bool) + verification_level = Field(enum(VerificationLevel)) + mfa_level = Field(int) + features = Field(listof(str)) + members = Field(dictof(GuildMember, key='id')) + channels = Field(dictof(Channel, key='id')) + roles = Field(dictof(Role, key='id')) + emojis = Field(dictof(Emoji, key='id')) + voice_states = Field(dictof(VoiceState, key='session_id')) def get_permissions(self, user): """ diff --git a/disco/types/invite.py b/disco/types/invite.py index f8fbc43..2bfb355 100644 --- a/disco/types/invite.py +++ b/disco/types/invite.py @@ -1,4 +1,4 @@ -from disco.types.base import Model, datetime +from disco.types.base import Model, Field, datetime from disco.types.user import User from disco.types.guild import Guild from disco.types.channel import Channel @@ -29,12 +29,12 @@ class Invite(Model): created_at : datetime When this invite was created. """ - code = str - inviter = User - guild = Guild - channel = Channel - max_age = int - max_uses = int - uses = int - temporary = bool - created_at = datetime + code = Field(str) + inviter = Field(User) + guild = Field(Guild) + channel = Field(Channel) + max_age = Field(int) + max_uses = Field(int) + uses = Field(int) + temporary = Field(bool) + created_at = Field(datetime) diff --git a/disco/types/message.py b/disco/types/message.py index de6d96e..ba88328 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -2,7 +2,7 @@ import re from holster.enum import Enum -from disco.types.base import Model, snowflake, text, datetime, dictof, listof, enum +from disco.types.base import Model, Field, snowflake, text, datetime, dictof, listof, enum from disco.util import to_snowflake from disco.util.functional import cached_property from disco.types.user import User @@ -34,10 +34,10 @@ class MessageEmbed(Model): url : str URL of the embed. """ - title = text - type = str - description = text - url = str + title = Field(text) + type = Field(str) + description = Field(text) + url = Field(str) class MessageAttachment(Model): @@ -61,13 +61,13 @@ class MessageAttachment(Model): width : int Width of the attachment. """ - id = str - filename = text - url = str - proxy_url = str - size = int - height = int - width = int + id = Field(str) + filename = Field(text) + url = Field(str) + proxy_url = Field(str) + size = Field(int) + height = Field(int) + width = Field(int) class Message(Model): @@ -107,21 +107,21 @@ class Message(Model): attachments : list(:class:`MessageAttachment`) All attachments for this message. """ - id = snowflake - channel_id = snowflake - type = enum(MessageType) - author = User - content = text - nonce = snowflake - timestamp = datetime - edited_timestamp = datetime - tts = bool - mention_everyone = bool - pinned = bool - mentions = dictof(User, key='id') - mention_roles = listof(snowflake) - embeds = listof(MessageEmbed) - attachments = dictof(MessageAttachment, key='id') + id = Field(snowflake) + channel_id = Field(snowflake) + type = Field(enum(MessageType)) + author = Field(User) + content = Field(text) + nonce = Field(snowflake) + timestamp = Field(datetime) + edited_timestamp = Field(datetime) + tts = Field(bool) + mention_everyone = Field(bool) + pinned = Field(bool) + mentions = Field(dictof(User, key='id')) + mention_roles = Field(listof(snowflake)) + embeds = Field(listof(MessageEmbed)) + attachments = Field(dictof(MessageAttachment, key='id')) def __str__(self): return ''.format(self.id, self.channel_id) diff --git a/disco/types/user.py b/disco/types/user.py index 73f6ff7..160f983 100644 --- a/disco/types/user.py +++ b/disco/types/user.py @@ -1,13 +1,13 @@ -from disco.types.base import Model, snowflake, text, binary +from disco.types.base import Model, Field, snowflake, text, binary class User(Model): - id = snowflake - username = text - discriminator = str - avatar = binary - verified = bool - email = str + id = Field(snowflake) + username = Field(text) + discriminator = Field(str) + avatar = Field(binary) + verified = Field(bool) + email = Field(str) def to_string(self): return '{}#{}'.format(self.username, self.discriminator) diff --git a/disco/types/voice.py b/disco/types/voice.py index a73e180..aadfa58 100644 --- a/disco/types/voice.py +++ b/disco/types/voice.py @@ -1,16 +1,16 @@ -from disco.types.base import Model, snowflake +from disco.types.base import Model, Field, snowflake class VoiceState(Model): - session_id = str - guild_id = snowflake - channel_id = snowflake - user_id = snowflake - deaf = bool - mute = bool - self_deaf = bool - self_mute = bool - suppress = bool + session_id = Field(str) + guild_id = Field(snowflake) + channel_id = Field(snowflake) + user_id = Field(snowflake) + deaf = Field(bool) + mute = Field(bool) + self_deaf = Field(bool) + self_mute = Field(bool) + suppress = Field(bool) @property def guild(self): diff --git a/docs/api.rst b/docs/api.rst index 808483d..c090248 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -118,7 +118,7 @@ GatewayClient Gateway Events ~~~~~~~~~~~~~~ -.. automodule:: disco.gateway.client.Events +.. automodule:: disco.gateway.events :members: diff --git a/examples/basic_plugin.py b/examples/basic_plugin.py index 3bc1746..8796745 100644 --- a/examples/basic_plugin.py +++ b/examples/basic_plugin.py @@ -10,10 +10,8 @@ from disco.types.permissions import Permissions class BasicPlugin(Plugin): @Plugin.listen('MessageCreate') - def on_message_create(self, event): - self.log.info('Message created: <{}>: {}'.format( - event.message.author.username, - event.message.content)) + def on_message_create(self, msg): + self.log.info('Message created: {}: {}'.format(msg.author, msg.content)) @Plugin.command('status', '[component]') def on_status_command(self, event, component=None):