Browse Source

Modeling improvements, couple other fixes

Modeling fields are now drastically better, no more dictof/listof
bullshit, we now properly have ListField/DictField/etc.

- Fix setting self nickname
pull/10/head
andrei 9 years ago
parent
commit
9891b900d6
  1. 3
      disco/api/client.py
  2. 1
      disco/api/http.py
  3. 24
      disco/gateway/events.py
  4. 126
      disco/types/base.py
  5. 9
      disco/types/channel.py
  6. 40
      disco/types/guild.py
  7. 17
      disco/types/message.py

3
disco/api/client.py

@ -196,6 +196,9 @@ class APIClient(LoggingClass):
def guilds_members_modify(self, guild, member, **kwargs): def guilds_members_modify(self, guild, member, **kwargs):
self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs) self.http(Routes.GUILDS_MEMBERS_MODIFY, dict(guild=guild, member=member), json=kwargs)
def guilds_members_me_nick(self, guild, nick):
self.http(Routes.GUILDS_MEMBERS_ME_NICK, dict(guild=guild), json={'nick': nick})
def guilds_members_kick(self, guild, member): def guilds_members_kick(self, guild, member):
self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member)) self.http(Routes.GUILDS_MEMBERS_KICK, dict(guild=guild, member=member))

1
disco/api/http.py

@ -71,6 +71,7 @@ class Routes(object):
GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members') GUILDS_MEMBERS_LIST = (HTTPMethod.GET, GUILDS + '/members')
GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}') GUILDS_MEMBERS_GET = (HTTPMethod.GET, GUILDS + '/members/{member}')
GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}') GUILDS_MEMBERS_MODIFY = (HTTPMethod.PATCH, GUILDS + '/members/{member}')
GUILDS_MEMBERS_ME_NICK = (HTTPMethod.PATCH, GUILDS + '/members/@me/nick')
GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}') GUILDS_MEMBERS_KICK = (HTTPMethod.DELETE, GUILDS + '/members/{member}')
GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans') GUILDS_BANS_LIST = (HTTPMethod.GET, GUILDS + '/bans')
GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}') GUILDS_BANS_CREATE = (HTTPMethod.PUT, GUILDS + '/bans/{user}')

24
disco/gateway/events.py

@ -9,7 +9,7 @@ from disco.types.message import Message, MessageReactionEmoji
from disco.types.voice import VoiceState from disco.types.voice import VoiceState
from disco.types.guild import Guild, GuildMember, Role, Emoji from disco.types.guild import Guild, GuildMember, Role, Emoji
from disco.types.base import Model, ModelMeta, Field, snowflake, listof, lazy_datetime from disco.types.base import Model, ModelMeta, Field, ListField, snowflake, lazy_datetime
# Mapping of discords event name to our event classes # Mapping of discords event name to our event classes
EVENTS_MAP = {} EVENTS_MAP = {}
@ -89,7 +89,7 @@ def wraps_model(model, alias=None):
def deco(cls): def deco(cls):
cls._fields[alias] = Field(model) cls._fields[alias] = Field(model)
cls._fields[alias].set_name(alias) cls._fields[alias].name = alias
cls._wraps_model = (alias, model) cls._wraps_model = (alias, model)
cls._proxy = alias cls._proxy = alias
return cls return cls
@ -124,8 +124,8 @@ class Ready(GatewayEvent):
version = Field(int, alias='v') version = Field(int, alias='v')
session_id = Field(str) session_id = Field(str)
user = Field(User) user = Field(User)
guilds = Field(listof(Guild)) guilds = ListField(Guild)
private_channels = Field(listof(Channel)) private_channels = ListField(Guild)
class Resumed(GatewayEvent): class Resumed(GatewayEvent):
@ -293,7 +293,7 @@ class GuildEmojisUpdate(GatewayEvent):
The new set of emojis for the guild The new set of emojis for the guild
""" """
guild_id = Field(snowflake) guild_id = Field(snowflake)
emojis = Field(listof(Emoji)) emojis = ListField(Emoji)
class GuildIntegrationsUpdate(GatewayEvent): class GuildIntegrationsUpdate(GatewayEvent):
@ -320,7 +320,7 @@ class GuildMembersChunk(GatewayEvent):
The chunk of members. The chunk of members.
""" """
guild_id = Field(snowflake) guild_id = Field(snowflake)
members = Field(listof(GuildMember)) members = ListField(GuildMember)
@property @property
def guild(self): def guild(self):
@ -466,6 +466,14 @@ class MessageDelete(GatewayEvent):
id = Field(snowflake) id = Field(snowflake)
channel_id = Field(snowflake) channel_id = Field(snowflake)
@property
def channel(self):
return self.client.state.channels.get(self.channel_id)
@property
def guild(self):
return self.channel.guild
class MessageDeleteBulk(GatewayEvent): class MessageDeleteBulk(GatewayEvent):
""" """
@ -479,7 +487,7 @@ class MessageDeleteBulk(GatewayEvent):
List of messages being deleted in the channel. List of messages being deleted in the channel.
""" """
channel_id = Field(snowflake) channel_id = Field(snowflake)
ids = Field(listof(snowflake)) ids = ListField(snowflake)
@wraps_model(Presence) @wraps_model(Presence)
@ -497,7 +505,7 @@ class PresenceUpdate(GatewayEvent):
List of roles the user from the presence is part of. List of roles the user from the presence is part of.
""" """
guild_id = Field(snowflake) guild_id = Field(snowflake)
roles = Field(listof(snowflake)) roles = ListField(snowflake)
class TypingStart(GatewayEvent): class TypingStart(GatewayEvent):

126
disco/types/base.py

@ -1,4 +1,5 @@
import six import six
import sys
import gevent import gevent
import inspect import inspect
import functools import functools
@ -19,49 +20,31 @@ class ConversionError(Exception):
def __init__(self, field, raw, e): def __init__(self, field, raw, e):
super(ConversionError, self).__init__( super(ConversionError, self).__init__(
'Failed to convert `{}` (`{}`) to {}: {}'.format( 'Failed to convert `{}` (`{}`) to {}: {}'.format(
str(raw)[:144], field.src_name, field.typ, e)) str(raw)[:144], field.src_name, field.deserializer, e))
class FieldType(object): class Field(object):
def __init__(self, typ): def __init__(self, value_type, alias=None, default=None):
if isinstance(typ, FieldType) or inspect.isclass(typ) and issubclass(typ, Model): self.src_name = alias
self.typ = typ self.dst_name = None
elif isinstance(typ, BaseEnumMeta):
self.typ = lambda raw, _: typ.get(raw)
elif typ is None:
self.typ = lambda x, y: None
else:
self.typ = lambda raw, _: typ(raw)
def serialize(self, value):
if isinstance(value, EnumAttr):
return value.value
elif isinstance(value, Model):
return value.to_dict()
else:
return value
def try_convert(self, raw, client):
pass
def __call__(self, raw, client):
return self.try_convert(raw, client)
if not hasattr(self, 'default'):
self.default = default
class Field(FieldType): self.deserializer = None
def __init__(self, typ, alias=None, default=None):
super(Field, self).__init__(typ)
# Set names if value_type:
self.src_name = alias self.deserializer = self.type_to_deserializer(value_type)
self.dst_name = None
self.default = default if isinstance(self.deserializer, Field):
self.default = self.deserializer.default
if isinstance(self.typ, FieldType): @property
self.default = self.typ.default def name(self):
return None
def set_name(self, name): @name.setter
def name(self, name):
if not self.dst_name: if not self.dst_name:
self.dst_name = name self.dst_name = name
@ -73,31 +56,68 @@ class Field(FieldType):
def try_convert(self, raw, client): def try_convert(self, raw, client):
try: try:
return self.typ(raw, client) return self.deserializer(raw, client)
except Exception as e: except Exception as e:
six.raise_from(ConversionError(self, raw, e), e) exc_info = sys.exc_info()
raise ConversionError(self, raw, e), exc_info[1], exc_info[2]
@staticmethod
def type_to_deserializer(typ):
if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model):
return typ
elif isinstance(typ, BaseEnumMeta):
return lambda raw, _: typ.get(raw)
elif typ is None:
return lambda x, y: None
else:
return lambda raw, _: typ(raw)
class _Dict(FieldType): @staticmethod
def serialize(value):
if isinstance(value, EnumAttr):
return value.value
elif isinstance(value, Model):
return value.to_dict()
else:
return value
def __call__(self, raw, client):
return self.try_convert(raw, client)
class DictField(Field):
default = HashMap default = HashMap
def __init__(self, typ, key=None): def __init__(self, key_type, value_type=None, **kwargs):
super(_Dict, self).__init__(typ) super(DictField, self).__init__(None, **kwargs)
self.key = key self.key_de = self.type_to_deserializer(key_type)
self.value_de = self.type_to_deserializer(value_type or key_type)
def try_convert(self, raw, client): def try_convert(self, raw, client):
if self.key: return HashMap({
converted = [self.typ(i, client) for i in raw] self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw)
return HashMap({getattr(i, self.key): i for i in converted}) })
else:
return HashMap({k: self.typ(v, client) for k, v in six.iteritems(raw)})
class _List(FieldType): class ListField(Field):
default = list default = list
def try_convert(self, raw, client): def try_convert(self, raw, client):
return [self.typ(i, client) for i in raw] return [self.deserializer(i, client) for i in raw]
class AutoDictField(Field):
default = HashMap
def __init__(self, value_type, key, **kwargs):
super(AutoDictField, self).__init__(None, **kwargs)
self.value_de = self.type_to_deserializer(value_type)
self.key = key
def try_convert(self, raw, client):
return HashMap({
getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw)
})
def _make(typ, data, client): def _make(typ, data, client):
@ -116,14 +136,6 @@ def enum(typ):
return _f return _f
def listof(*args, **kwargs):
return _List(*args, **kwargs)
def dictof(*args, **kwargs):
return _Dict(*args, **kwargs)
def lazy_datetime(data): def lazy_datetime(data):
if not data: if not data:
return property(lambda: None) return property(lambda: None)
@ -201,7 +213,7 @@ class ModelMeta(type):
if not isinstance(v, Field): if not isinstance(v, Field):
continue continue
v.set_name(k) v.name = k
fields[k] = v fields[k] = v
if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)): if SlottedModel and any(map(lambda k: issubclass(k, SlottedModel), parents)):

9
disco/types/channel.py

@ -5,7 +5,7 @@ from holster.enum import Enum
from disco.util.snowflake import to_snowflake from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property, one_or_many, chunks from disco.util.functional import cached_property, one_or_many, chunks
from disco.types.user import User from disco.types.user import User
from disco.types.base import SlottedModel, Field, snowflake, enum, listof, dictof, text from disco.types.base import SlottedModel, Field, ListField, AutoDictField, snowflake, enum, text
from disco.types.permissions import Permissions, Permissible, PermissionValue from disco.types.permissions import Permissions, Permissible, PermissionValue
from disco.voice.client import VoiceClient from disco.voice.client import VoiceClient
@ -111,15 +111,18 @@ class Channel(SlottedModel, Permissible):
last_message_id = Field(snowflake) last_message_id = Field(snowflake)
position = Field(int) position = Field(int)
bitrate = Field(int) bitrate = Field(int)
recipients = Field(listof(User)) recipients = ListField(User)
type = Field(enum(ChannelType)) type = Field(enum(ChannelType))
overwrites = Field(dictof(PermissionOverwrite, key='id'), alias='permission_overwrites') overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites')
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Channel, self).__init__(*args, **kwargs) super(Channel, self).__init__(*args, **kwargs)
self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self})
def __str__(self):
return '#{}'.format(self.name)
def get_permissions(self, user): def get_permissions(self, user):
""" """
Get the permissions a user has in the channel Get the permissions a user has in the channel

40
disco/types/guild.py

@ -6,7 +6,9 @@ from disco.gateway.packets import OPCode
from disco.api.http import APIException from disco.api.http import APIException
from disco.util.snowflake import to_snowflake from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property
from disco.types.base import SlottedModel, Field, snowflake, listof, dictof, text, binary, enum from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text, binary, enum
)
from disco.types.user import User, Presence from disco.types.user import User, Presence
from disco.types.voice import VoiceState from disco.types.voice import VoiceState
from disco.types.channel import Channel from disco.types.channel import Channel
@ -45,7 +47,7 @@ class GuildEmoji(Emoji):
name = Field(text) name = Field(text)
require_colons = Field(bool) require_colons = Field(bool)
managed = Field(bool) managed = Field(bool)
roles = Field(listof(snowflake)) roles = ListField(snowflake)
@cached_property @cached_property
def guild(self): def guild(self):
@ -128,7 +130,7 @@ class GuildMember(SlottedModel):
mute = Field(bool) mute = Field(bool)
deaf = Field(bool) deaf = Field(bool)
joined_at = Field(str) joined_at = Field(str)
roles = Field(listof(snowflake)) roles = ListField(snowflake)
def __str__(self): def __str__(self):
return self.user.__str__() return self.user.__str__()
@ -169,7 +171,10 @@ class GuildMember(SlottedModel):
nickname : Optional[str] nickname : Optional[str]
The nickname (or none to reset) to set. The nickname (or none to reset) to set.
""" """
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '') if self.client.state.me.id == self.user.id:
self.client.api.guilds_members_me_nick(self.guild.id, nick=nickname or '')
else:
self.client.api.guilds_members_modify(self.guild.id, self.user.id, nick=nickname or '')
def add_role(self, role): def add_role(self, role):
roles = self.roles + [role.id] roles = self.roles + [role.id]
@ -196,6 +201,10 @@ class GuildMember(SlottedModel):
def guild(self): def guild(self):
return self.client.state.guilds.get(self.guild_id) return self.client.state.guilds.get(self.guild_id)
@cached_property
def permissions(self):
return self.guild.get_permissions(self)
class Guild(SlottedModel, Permissible): class Guild(SlottedModel, Permissible):
""" """
@ -252,14 +261,14 @@ class Guild(SlottedModel, Permissible):
embed_enabled = Field(bool) embed_enabled = Field(bool)
verification_level = Field(enum(VerificationLevel)) verification_level = Field(enum(VerificationLevel))
mfa_level = Field(int) mfa_level = Field(int)
features = Field(listof(str)) features = ListField(str)
members = Field(dictof(GuildMember, key='id')) members = AutoDictField(GuildMember, 'id')
channels = Field(dictof(Channel, key='id')) channels = AutoDictField(Channel, 'id')
roles = Field(dictof(Role, key='id')) roles = AutoDictField(Role, 'id')
emojis = Field(dictof(GuildEmoji, key='id')) emojis = AutoDictField(GuildEmoji, 'id')
voice_states = Field(dictof(VoiceState, key='session_id')) voice_states = AutoDictField(VoiceState, 'session_id')
member_count = Field(int) member_count = Field(int)
presences = Field(listof(Presence)) presences = ListField(Presence)
synced = Field(bool, default=False) synced = Field(bool, default=False)
@ -272,7 +281,7 @@ class Guild(SlottedModel, Permissible):
self.attach(six.itervalues(self.emojis), {'guild_id': self.id}) self.attach(six.itervalues(self.emojis), {'guild_id': self.id})
self.attach(six.itervalues(self.voice_states), {'guild_id': self.id}) self.attach(six.itervalues(self.voice_states), {'guild_id': self.id})
def get_permissions(self, user): def get_permissions(self, member):
""" """
Get the permissions a user has in this guild. Get the permissions a user has in this guild.
@ -281,10 +290,13 @@ class Guild(SlottedModel, Permissible):
:class:`disco.types.permissions.PermissionValue` :class:`disco.types.permissions.PermissionValue`
Computed permission value for the user. Computed permission value for the user.
""" """
if self.owner_id == user.id: if not isinstance(member, GuildMember):
member = self.get_member(member)
# Owner has all permissions
if self.owner_id == member.id:
return PermissionValue(Permissions.ADMINISTRATOR) return PermissionValue(Permissions.ADMINISTRATOR)
member = self.get_member(user)
value = PermissionValue(self.roles.get(self.id).permissions) value = PermissionValue(self.roles.get(self.id).permissions)
for role in map(self.roles.get, member.roles): for role in map(self.roles.get, member.roles):

17
disco/types/message.py

@ -2,7 +2,10 @@ import re
from holster.enum import Enum from holster.enum import Enum
from disco.types.base import SlottedModel, Field, snowflake, text, lazy_datetime, dictof, listof, enum from disco.types.base import (
SlottedModel, Field, ListField, AutoDictField, snowflake, text,
lazy_datetime, enum
)
from disco.util.snowflake import to_snowflake from disco.util.snowflake import to_snowflake
from disco.util.functional import cached_property from disco.util.functional import cached_property
from disco.types.user import User from disco.types.user import User
@ -109,7 +112,7 @@ class MessageEmbed(SlottedModel):
thumbnail = Field(MessageEmbedThumbnail) thumbnail = Field(MessageEmbedThumbnail)
video = Field(MessageEmbedVideo) video = Field(MessageEmbedVideo)
author = Field(MessageEmbedAuthor) author = Field(MessageEmbedAuthor)
fields = Field(listof(MessageEmbedField)) fields = ListField(MessageEmbedField)
class MessageAttachment(SlottedModel): class MessageAttachment(SlottedModel):
@ -191,11 +194,11 @@ class Message(SlottedModel):
tts = Field(bool) tts = Field(bool)
mention_everyone = Field(bool) mention_everyone = Field(bool)
pinned = Field(bool) pinned = Field(bool)
mentions = Field(dictof(User, key='id')) mentions = AutoDictField(User, 'id')
mention_roles = Field(listof(snowflake)) mention_roles = ListField(snowflake)
embeds = Field(listof(MessageEmbed)) embeds = ListField(MessageEmbed)
attachments = Field(dictof(MessageAttachment, key='id')) attachments = AutoDictField(MessageAttachment, 'id')
reactions = Field(listof(MessageReaction)) reactions = ListField(MessageReaction)
def __str__(self): def __str__(self):
return '<Message {} ({})>'.format(self.id, self.channel_id) return '<Message {} ({})>'.format(self.id, self.channel_id)

Loading…
Cancel
Save