Browse Source

Modeling improvements, couple other fixes

Modeling fields are now drastically better, no more dictof/listof
bullshit, we now properly have ListField/DictField/etc.

- Fix setting self nickname
pull/10/head
andrei 9 years ago
parent
commit
9891b900d6
  1. 3
      disco/api/client.py
  2. 1
      disco/api/http.py
  3. 24
      disco/gateway/events.py
  4. 126
      disco/types/base.py
  5. 9
      disco/types/channel.py
  6. 40
      disco/types/guild.py
  7. 17
      disco/types/message.py

3
disco/api/client.py

@ -196,6 +196,9 @@ class APIClient(LoggingClass):
def guilds_members_modify(self, guild, member, **kwargs):
self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs)
def guilds_members_me_nick(self, guild, nick):
self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick})
def guilds_members_kick(self, guild, member):
self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member))

1
disco/api/http.py

@ -71,6 +71,7 @@ class Routes(object):
GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members')
GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}')
GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}')
GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick')
GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}')
GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans')
GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}')

24
disco/gateway/events.py

@ -9,7 +9,7 @@ from disco.types.message import Message, MessageReactionEmoji
from disco.types.voice import VoiceState
from disco.types.guild import Guild, GuildMember, Role, Emoji
from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime
from disco.types.base import Model, ModelMeta, Field, ListField, snowflake, lazy_datetime
# Mapping of discords event name to our event classes
EVENTS_MAP = {}
@ -89,7 +89,7 @@ def wraps_model(model, alias=None):
def deco(cls):
cls._fields[alias] = Field(model)
cls._fields[alias].set_name(alias)
cls._fields[alias].name = alias
cls._wraps_model = (alias, model)
cls._proxy = alias
return cls
@ -124,8 +124,8 @@ class Ready(GatewayEvent):
version = Field(int, alias='v')
session_id = Field(str)
user = Field(User)
guilds = Field(listof(Guild))
private_channels = Field(listof(Channel))
guilds = ListField(Guild)
private_channels = ListField(Guild)
class Resumed(GatewayEvent):
@ -293,7 +293,7 @@ class GuildEmojisUpdate(GatewayEvent):
The new set of emojis for the guild
"""
guild_id = Field(snowflake)
emojis = Field(listof(Emoji))
emojis = ListField(Emoji)
class GuildIntegrationsUpdate(GatewayEvent):
@ -320,7 +320,7 @@ class GuildMembersChunk(GatewayEvent):
The chunk of members.
"""
guild_id = Field(snowflake)
members = Field(listof(GuildMember))
members = ListField(GuildMember)
@property
def guild(self):
@ -466,6 +466,14 @@ class MessageDelete(GatewayEvent):
id = Field(snowflake)
channel_id = Field(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageDeleteBulk(GatewayEvent):
"""
@ -479,7 +487,7 @@ class MessageDeleteBulk(GatewayEvent):
List of messages being deleted in the channel.
"""
channel_id = Field(snowflake)
ids = Field(listof(snowflake))
ids = ListField(snowflake)
@wraps_model(Presence)
@ -497,7 +505,7 @@ class PresenceUpdate(GatewayEvent):
List of roles the user from the presence is part of.
"""
guild_id = Field(snowflake)
roles = Field(listof(snowflake))
roles = ListField(snowflake)
class TypingStart(GatewayEvent):

126
disco/types/base.py

@ -1,4 +1,5 @@
import six
import sys
import gevent
import inspect
import functools
@ -19,49 +20,31 @@ class ConversionError(Exception):
def __init__(self, field, raw, e):
super(ConversionError, self).__init__(
'Failed to convert `{}` (`{}`) to {}: {}'.format(
str(raw)[:144], field.src_name, field.typ, e))
str(raw)[:144], field.src_name, field.deserializer, e))
class FieldType(object):
def __init__(self, typ):
if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model):
self.typ = typ
elif isinstance(typ, BaseEnumMeta):
self.typ = lambda raw, _: typ.get(raw)
elif typ is None:
self.typ = lambda x, y: None
else:
self.typ = lambda raw, _: typ(raw)
def serialize(self, value):
if isinstance(value, EnumAttr):
return value.value
elif isinstance(value, Model):
return value.to_dict()
else:
return value
def try_convert(self, raw, client):
pass
def __call__(self, raw, client):
return self.try_convert(raw, client)
class Field(object):
def __init__(self, value_type, alias=None, default=None):
self.src_name = alias
self.dst_name = None
if not hasattr(self, 'default'):
self.default = default
class Field(FieldType):
def __init__(self, typ, alias=None, default=None):
super(Field, self).__init__(typ)
self.deserializer = None
# Set names
self.src_name = alias
self.dst_name = None
if value_type:
self.deserializer = self.type_to_deserializer(value_type)
self.default = default
if isinstance(self.deserializer, Field):
self.default = self.deserializer.default
if isinstance(self.typ, FieldType):
self.default = self.typ.default
@property
def name(self):
return None
def set_name(self, name):
@name.setter
def name(self, name):
if not self.dst_name:
self.dst_name = name
@ -73,31 +56,68 @@ class Field(FieldType):
def try_convert(self, raw, client):
try:
return self.typ(raw, client)
return self.deserializer(raw, client)
except Exception as e:
six.raise_from(ConversionError(self, raw, e), e)
exc_info = sys.exc_info()
raise ConversionError(self, raw, e), exc_info[1], exc_info[2]
@staticmethod
def type_to_deserializer(typ):
if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model):
return typ
elif isinstance(typ, BaseEnumMeta):
return lambda raw, _: typ.get(raw)
elif typ is None:
return lambda x, y: None
else:
return lambda raw, _: typ(raw)
class _Dict(FieldType):
@staticmethod
def serialize(value):
if isinstance(value, EnumAttr):
return value.value
elif isinstance(value, Model):
return value.to_dict()
else:
return value
def __call__(self, raw, client):
return self.try_convert(raw, client)
class DictField(Field):
default = HashMap
def __init__(self, typ, key=None):
super(_Dict, self).__init__(typ)
self.key = key
def __init__(self, key_type, value_type=None, **kwargs):
super(DictField, self).__init__(None, **kwargs)
self.key_de = self.type_to_deserializer(key_type)
self.value_de = self.type_to_deserializer(value_type or key_type)
def try_convert(self, raw, client):
if self.key:
converted = [self.typ(i, client) for i in raw]
return HashMap({getattr(i, self.key): i for i in converted})
else:
return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)})
return HashMap({
self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw)
})
class _List(FieldType):
class ListField(Field):
default = list
def try_convert(self, raw, client):
return [self.typ(i, client) for i in raw]
return [self.deserializer(i, client) for i in raw]
class AutoDictField(Field):
default = HashMap
def __init__(self, value_type, key, **kwargs):
super(AutoDictField, self).__init__(None, **kwargs)
self.value_de = self.type_to_deserializer(value_type)
self.key = key
def try_convert(self, raw, client):
return HashMap({
getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw)
})
def _make(typ, data, client):
@ -116,14 +136,6 @@ def enum(typ):
return _f
def listof(*args, **kwargs):
return _List(*args, **kwargs)
def dictof(*args, **kwargs):
return _Dict(*args, **kwargs)
def lazy_datetime(data):
if not data:
return property(lambda: None)
@ -201,7 +213,7 @@ class ModelMeta(type):
if not isinstance(v, Field):
continue
v.set_name(k)
v.name = k
fields[k] = v
if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)):

9
disco/types/channel.py

@ -5,7 +5,7 @@ from holster.enum import Enum
from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property, one_or_many, chunks
from disco.types.user import User
from disco.types.base import SlottedModel, Field, snowflake, enum, listof, dictof, text
from disco.types.base import SlottedModel, Field, ListField, AutoDictField, snowflake, enum, text
from disco.types.permissions import Permissions, Permissible, PermissionValue
from disco.voice.client import VoiceClient
@ -111,15 +111,18 @@ class Channel(SlottedModel, Permissible):
last_message_id = Field(snowflake)
position = Field(int)
bitrate = Field(int)
recipients = Field(listof(User))
recipients = ListField(User)
type = Field(enum(ChannelType))
overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites')
overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites')
def __init__(self, *args, **kwargs):
super(Channel, self).__init__(*args, **kwargs)
self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self})
def __str__(self):
return '#{}'.format(self.name)
def get_permissions(self, user):
"""
Get the permissions a user has in the channel

40
disco/types/guild.py

@ -6,7 +6,9 @@ from disco.gateway.packets import OPCode
from disco.api.http import APIException
from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property
from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum
from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum
)
from disco.types.user import User, Presence
from disco.types.voice import VoiceState
from disco.types.channel import Channel
@ -45,7 +47,7 @@ class GuildEmoji(Emoji):
name = Field(text)
require_colons = Field(bool)
managed = Field(bool)
roles = Field(listof(snowflake))
roles = ListField(snowflake)
@cached_property
def guild(self):
@ -128,7 +130,7 @@ class GuildMember(SlottedModel):
mute = Field(bool)
deaf = Field(bool)
joined_at = Field(str)
roles = Field(listof(snowflake))
roles = ListField(snowflake)
def __str__(self):
return self.user.__str__()
@ -169,7 +171,10 @@ class GuildMember(SlottedModel):
nickname : Optional[str]
The nickname (or none to reset) to set.
"""
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '')
if self.client.state.me.id == self.user.id:
self.client.api.guilds_members_me_nick(self.guild.id, nick=nickname or '')
else:
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '')
def add_role(self, role):
roles = self.roles + [role.id]
@ -196,6 +201,10 @@ class GuildMember(SlottedModel):
def guild(self):
return self.client.state.guilds.get(self.guild_id)
@cached_property
def permissions(self):
return self.guild.get_permissions(self)
class Guild(SlottedModel, Permissible):
"""
@ -252,14 +261,14 @@ class Guild(SlottedModel, Permissible):
embed_enabled = Field(bool)
verification_level = Field(enum(VerificationLevel))
mfa_level = Field(int)
features = Field(listof(str))
members = Field(dictof(GuildMember, key='id'))
channels = Field(dictof(Channel, key='id'))
roles = Field(dictof(Role, key='id'))
emojis = Field(dictof(GuildEmoji, key='id'))
voice_states = Field(dictof(VoiceState, key='session_id'))
features = ListField(str)
members = AutoDictField(GuildMember, 'id')
channels = AutoDictField(Channel, 'id')
roles = AutoDictField(Role, 'id')
emojis = AutoDictField(GuildEmoji, 'id')
voice_states = AutoDictField(VoiceState, 'session_id')
member_count = Field(int)
presences = Field(listof(Presence))
presences = ListField(Presence)
synced = Field(bool, default=False)
@ -272,7 +281,7 @@ class Guild(SlottedModel, Permissible):
self.attach(six.itervalues(self.emojis), {'guild_id': self.id})
self.attach(six.itervalues(self.voice_states), {'guild_id': self.id})
def get_permissions(self, user):
def get_permissions(self, member):
"""
Get the permissions a user has in this guild.
@ -281,10 +290,13 @@ class Guild(SlottedModel, Permissible):
:class:`disco.types.permissions.PermissionValue`
Computed permission value for the user.
"""
if self.owner_id == user.id:
if not isinstance(member, GuildMember):
member = self.get_member(member)
# Owner has all permissions
if self.owner_id == member.id:
return PermissionValue(Permissions.ADMINISTRATOR)
member = self.get_member(user)
value = PermissionValue(self.roles.get(self.id).permissions)
for role in map(self.roles.get, member.roles):

17
disco/types/message.py

@ -2,7 +2,10 @@ import re
from holster.enum import Enum
from disco.types.base import SlottedModel, Field, snowflake, text, lazy_datetime, dictof, listof, enum
from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text,
lazy_datetime, enum
)
from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property
from disco.types.user import User
@ -109,7 +112,7 @@ class MessageEmbed(SlottedModel):
thumbnail = Field(MessageEmbedThumbnail)
video = Field(MessageEmbedVideo)
author = Field(MessageEmbedAuthor)
fields = Field(listof(MessageEmbedField))
fields = ListField(MessageEmbedField)
class MessageAttachment(SlottedModel):
@ -191,11 +194,11 @@ class Message(SlottedModel):
tts = Field(bool)
mention_everyone = Field(bool)
pinned = Field(bool)
mentions = Field(dictof(User, key='id'))
mention_roles = Field(listof(snowflake))
embeds = Field(listof(MessageEmbed))
attachments = Field(dictof(MessageAttachment, key='id'))
reactions = Field(listof(MessageReaction))
mentions = AutoDictField(User, 'id')
mention_roles = ListField(snowflake)
embeds = ListField(MessageEmbed)
attachments = AutoDictField(MessageAttachment, 'id')
reactions = ListField(MessageReaction)
def __str__(self):
return '<Message {} ({})>'.format(self.id, self.channel_id)

Loading…
Cancel
Save