From 9891b900d6c28cb1df3f07a6bd732a741ee1e029 Mon Sep 17 00:00:00 2001 From: andrei Date: Sun, 30 Oct 2016 19:49:45 -0700 Subject: [PATCH] Modeling improvements, couple other fixes Modeling fields are now drastically better, no more dictof/listof bullshit, we now properly have ListField/DictField/etc. - Fix setting self nickname --- disco/api/client.py | 3 + disco/api/http.py | 1 + disco/gateway/events.py | 24 +++++--- disco/types/base.py | 126 ++++++++++++++++++++++------------------ disco/types/channel.py | 9 ++- disco/types/guild.py | 40 ++++++++----- disco/types/message.py | 17 +++--- 7 files changed, 131 insertions(+), 89 deletions(-) diff --git a/disco/api/client.py b/disco/api/client.py index df13194..27152e4 100644 --- a/disco/api/client.py +++ b/disco/api/client.py @@ -196,6 +196,9 @@ class APIClient(LoggingClass): def guilds_members_modify(self, guild, member, **kwargs): self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) + def guilds_members_me_nick(self, guild, nick): + self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick}) + def guilds_members_kick(self, guild, member): self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member)) diff --git a/disco/api/http.py b/disco/api/http.py index 7fafbf6..31035e4 100644 --- a/disco/api/http.py +++ b/disco/api/http.py @@ -71,6 +71,7 @@ class Routes(object): GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members') GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}') GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}') + GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick') GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}') GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans') GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}') diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6b612ea..6799795 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -9,7 +9,7 @@ from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState from disco.types.guild import Guild, GuildMember, Role, Emoji -from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime +from disco.types.base import Model, ModelMeta, Field, ListField, snowflake, lazy_datetime # Mapping of discords event name to our event classes EVENTS_MAP = {} @@ -89,7 +89,7 @@ def wraps_model(model, alias=None): def deco(cls): cls._fields[alias] = Field(model) - cls._fields[alias].set_name(alias) + cls._fields[alias].name = alias cls._wraps_model = (alias, model) cls._proxy = alias return cls @@ -124,8 +124,8 @@ class Ready(GatewayEvent): version = Field(int, alias='v') session_id = Field(str) user = Field(User) - guilds = Field(listof(Guild)) - private_channels = Field(listof(Channel)) + guilds = ListField(Guild) + private_channels = ListField(Guild) class Resumed(GatewayEvent): @@ -293,7 +293,7 @@ class GuildEmojisUpdate(GatewayEvent): The new set of emojis for the guild """ guild_id = Field(snowflake) - emojis = Field(listof(Emoji)) + emojis = ListField(Emoji) class GuildIntegrationsUpdate(GatewayEvent): @@ -320,7 +320,7 @@ class GuildMembersChunk(GatewayEvent): The chunk of members. """ guild_id = Field(snowflake) - members = Field(listof(GuildMember)) + members = ListField(GuildMember) @property def guild(self): @@ -466,6 +466,14 @@ class MessageDelete(GatewayEvent): id = Field(snowflake) channel_id = Field(snowflake) + @property + def channel(self): + return self.client.state.channels.get(self.channel_id) + + @property + def guild(self): + return self.channel.guild + class MessageDeleteBulk(GatewayEvent): """ @@ -479,7 +487,7 @@ class MessageDeleteBulk(GatewayEvent): List of messages being deleted in the channel. """ channel_id = Field(snowflake) - ids = Field(listof(snowflake)) + ids = ListField(snowflake) @wraps_model(Presence) @@ -497,7 +505,7 @@ class PresenceUpdate(GatewayEvent): List of roles the user from the presence is part of. """ guild_id = Field(snowflake) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) class TypingStart(GatewayEvent): diff --git a/disco/types/base.py b/disco/types/base.py index 1fd69fc..3bea43a 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -1,4 +1,5 @@ import six +import sys import gevent import inspect import functools @@ -19,49 +20,31 @@ class ConversionError(Exception): def __init__(self, field, raw, e): super(ConversionError, self).__init__( 'Failed to convert `{}` (`{}`) to {}: {}'.format( - str(raw)[:144], field.src_name, field.typ, e)) + str(raw)[:144], field.src_name, field.deserializer, e)) -class FieldType(object): - def __init__(self, typ): - if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model): - self.typ = typ - elif isinstance(typ, BaseEnumMeta): - self.typ = lambda raw, _: typ.get(raw) - elif typ is None: - self.typ = lambda x, y: None - else: - self.typ = lambda raw, _: typ(raw) - - def serialize(self, value): - if isinstance(value, EnumAttr): - return value.value - elif isinstance(value, Model): - return value.to_dict() - else: - return value - - def try_convert(self, raw, client): - pass - - def __call__(self, raw, client): - return self.try_convert(raw, client) +class Field(object): + def __init__(self, value_type, alias=None, default=None): + self.src_name = alias + self.dst_name = None + if not hasattr(self, 'default'): + self.default = default -class Field(FieldType): - def __init__(self, typ, alias=None, default=None): - super(Field, self).__init__(typ) + self.deserializer = None - # Set names - self.src_name = alias - self.dst_name = None + if value_type: + self.deserializer = self.type_to_deserializer(value_type) - self.default = default + if isinstance(self.deserializer, Field): + self.default = self.deserializer.default - if isinstance(self.typ, FieldType): - self.default = self.typ.default + @property + def name(self): + return None - def set_name(self, name): + @name.setter + def name(self, name): if not self.dst_name: self.dst_name = name @@ -73,31 +56,68 @@ class Field(FieldType): def try_convert(self, raw, client): try: - return self.typ(raw, client) + return self.deserializer(raw, client) except Exception as e: - six.raise_from(ConversionError(self, raw, e), e) + exc_info = sys.exc_info() + raise ConversionError(self, raw, e), exc_info[1], exc_info[2] + @staticmethod + def type_to_deserializer(typ): + if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model): + return typ + elif isinstance(typ, BaseEnumMeta): + return lambda raw, _: typ.get(raw) + elif typ is None: + return lambda x, y: None + else: + return lambda raw, _: typ(raw) -class _Dict(FieldType): + @staticmethod + def serialize(value): + if isinstance(value, EnumAttr): + return value.value + elif isinstance(value, Model): + return value.to_dict() + else: + return value + + def __call__(self, raw, client): + return self.try_convert(raw, client) + + +class DictField(Field): default = HashMap - def __init__(self, typ, key=None): - super(_Dict, self).__init__(typ) - self.key = key + def __init__(self, key_type, value_type=None, **kwargs): + super(DictField, self).__init__(None, **kwargs) + self.key_de = self.type_to_deserializer(key_type) + self.value_de = self.type_to_deserializer(value_type or key_type) def try_convert(self, raw, client): - if self.key: - converted = [self.typ(i, client) for i in raw] - return HashMap({getattr(i, self.key): i for i in converted}) - else: - return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)}) + return HashMap({ + self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw) + }) -class _List(FieldType): +class ListField(Field): default = list def try_convert(self, raw, client): - return [self.typ(i, client) for i in raw] + return [self.deserializer(i, client) for i in raw] + + +class AutoDictField(Field): + default = HashMap + + def __init__(self, value_type, key, **kwargs): + super(AutoDictField, self).__init__(None, **kwargs) + self.value_de = self.type_to_deserializer(value_type) + self.key = key + + def try_convert(self, raw, client): + return HashMap({ + getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw) + }) def _make(typ, data, client): @@ -116,14 +136,6 @@ def enum(typ): return _f -def listof(*args, **kwargs): - return _List(*args, **kwargs) - - -def dictof(*args, **kwargs): - return _Dict(*args, **kwargs) - - def lazy_datetime(data): if not data: return property(lambda: None) @@ -201,7 +213,7 @@ class ModelMeta(type): if not isinstance(v, Field): continue - v.set_name(k) + v.name = k fields[k] = v if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)): diff --git a/disco/types/channel.py b/disco/types/channel.py index 57d241d..97eeb38 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -5,7 +5,7 @@ from holster.enum import Enum from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property, one_or_many, chunks from disco.types.user import User -from disco.types.base import SlottedModel, Field, snowflake, enum, listof, dictof, text +from disco.types.base import SlottedModel, Field, ListField, AutoDictField, snowflake, enum, text from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.voice.client import VoiceClient @@ -111,15 +111,18 @@ class Channel(SlottedModel, Permissible): last_message_id = Field(snowflake) position = Field(int) bitrate = Field(int) - recipients = Field(listof(User)) + recipients = ListField(User) type = Field(enum(ChannelType)) - overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') + overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites') def __init__(self, *args, **kwargs): super(Channel, self).__init__(*args, **kwargs) self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) + def __str__(self): + return '#{}'.format(self.name) + def get_permissions(self, user): """ Get the permissions a user has in the channel diff --git a/disco/types/guild.py b/disco/types/guild.py index d8a2963..1f7ddd4 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -6,7 +6,9 @@ from disco.gateway.packets import OPCode from disco.api.http import APIException from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property -from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum +from disco.types.base import ( + SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum +) from disco.types.user import User, Presence from disco.types.voice import VoiceState from disco.types.channel import Channel @@ -45,7 +47,7 @@ class GuildEmoji(Emoji): name = Field(text) require_colons = Field(bool) managed = Field(bool) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) @cached_property def guild(self): @@ -128,7 +130,7 @@ class GuildMember(SlottedModel): mute = Field(bool) deaf = Field(bool) joined_at = Field(str) - roles = Field(listof(snowflake)) + roles = ListField(snowflake) def __str__(self): return self.user.__str__() @@ -169,7 +171,10 @@ class GuildMember(SlottedModel): nickname : Optional[str] The nickname (or none to reset) to set. """ - self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') + if self.client.state.me.id == self.user.id: + self.client.api.guilds_members_me_nick(self.guild.id, nick=nickname or '') + else: + self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') def add_role(self, role): roles = self.roles + [role.id] @@ -196,6 +201,10 @@ class GuildMember(SlottedModel): def guild(self): return self.client.state.guilds.get(self.guild_id) + @cached_property + def permissions(self): + return self.guild.get_permissions(self) + class Guild(SlottedModel, Permissible): """ @@ -252,14 +261,14 @@ class Guild(SlottedModel, Permissible): embed_enabled = Field(bool) verification_level = Field(enum(VerificationLevel)) mfa_level = Field(int) - features = Field(listof(str)) - members = Field(dictof(GuildMember, key='id')) - channels = Field(dictof(Channel, key='id')) - roles = Field(dictof(Role, key='id')) - emojis = Field(dictof(GuildEmoji, key='id')) - voice_states = Field(dictof(VoiceState, key='session_id')) + features = ListField(str) + members = AutoDictField(GuildMember, 'id') + channels = AutoDictField(Channel, 'id') + roles = AutoDictField(Role, 'id') + emojis = AutoDictField(GuildEmoji, 'id') + voice_states = AutoDictField(VoiceState, 'session_id') member_count = Field(int) - presences = Field(listof(Presence)) + presences = ListField(Presence) synced = Field(bool, default=False) @@ -272,7 +281,7 @@ class Guild(SlottedModel, Permissible): self.attach(six.itervalues(self.emojis), {'guild_id': self.id}) self.attach(six.itervalues(self.voice_states), {'guild_id': self.id}) - def get_permissions(self, user): + def get_permissions(self, member): """ Get the permissions a user has in this guild. @@ -281,10 +290,13 @@ class Guild(SlottedModel, Permissible): :class:`disco.types.permissions.PermissionValue` Computed permission value for the user. """ - if self.owner_id == user.id: + if not isinstance(member, GuildMember): + member = self.get_member(member) + + # Owner has all permissions + if self.owner_id == member.id: return PermissionValue(Permissions.ADMINISTRATOR) - member = self.get_member(user) value = PermissionValue(self.roles.get(self.id).permissions) for role in map(self.roles.get, member.roles): diff --git a/disco/types/message.py b/disco/types/message.py index ce9822f..e53ccb5 100644 --- a/disco/types/message.py +++ b/disco/types/message.py @@ -2,7 +2,10 @@ import re from holster.enum import Enum -from disco.types.base import SlottedModel, Field, snowflake, text, lazy_datetime, dictof, listof, enum +from disco.types.base import ( + SlottedModel, Field, ListField, AutoDictField, snowflake, text, + lazy_datetime, enum +) from disco.util.snowflake import to_snowflake from disco.util.functional import cached_property from disco.types.user import User @@ -109,7 +112,7 @@ class MessageEmbed(SlottedModel): thumbnail = Field(MessageEmbedThumbnail) video = Field(MessageEmbedVideo) author = Field(MessageEmbedAuthor) - fields = Field(listof(MessageEmbedField)) + fields = ListField(MessageEmbedField) class MessageAttachment(SlottedModel): @@ -191,11 +194,11 @@ class Message(SlottedModel): tts = Field(bool) mention_everyone = Field(bool) pinned = Field(bool) - mentions = Field(dictof(User, key='id')) - mention_roles = Field(listof(snowflake)) - embeds = Field(listof(MessageEmbed)) - attachments = Field(dictof(MessageAttachment, key='id')) - reactions = Field(listof(MessageReaction)) + mentions = AutoDictField(User, 'id') + mention_roles = ListField(snowflake) + embeds = ListField(MessageEmbed) + attachments = AutoDictField(MessageAttachment, 'id') + reactions = ListField(MessageReaction) def __str__(self): return ''.format(self.id, self.channel_id)