diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 6430d84..f74431e 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -4,12 +4,12 @@ import inflection import six from disco.types.user import User, Presence -from disco.types.channel import Channel +from disco.types.channel import Channel, PermissionOverwrite from disco.types.message import Message, MessageReactionEmoji from disco.types.voice import VoiceState from disco.types.guild import Guild, GuildMember, Role, GuildEmoji -from disco.types.base import Model, ModelMeta, Field, ListField, snowflake, lazy_datetime +from disco.types.base import Model, ModelMeta, Field, ListField, AutoDictField, snowflake, lazy_datetime # Mapping of discords event name to our event classes EVENTS_MAP = {} @@ -217,6 +217,7 @@ class ChannelUpdate(ChannelCreate): channel : :class:`disco.types.channel.Channel` The channel which was updated. """ + overwrites = AutoDictField(PermissionOverwrite, 'id', alias='permission_overwrites') @wraps_model(Channel) diff --git a/disco/state.py b/disco/state.py index 86bbcb9..ddca31c 100644 --- a/disco/state.py +++ b/disco/state.py @@ -5,6 +5,7 @@ import inflection from collections import deque, namedtuple from gevent.event import Event +from disco.types.base import UNSET from disco.util.config import Config from disco.util.hashmap import HashMap, DefaultHashMap @@ -211,6 +212,10 @@ class State(object): if event.channel.id in self.channels: self.channels[event.channel.id].update(event.channel) + if event.overwrites is not UNSET: + self.channels[event.channel.id].overwrites = event.overwrites + self.channels[event.channel.id].after_load() + def on_channel_delete(self, event): if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels: del event.channel.guild.channels[event.channel.id] diff --git a/disco/types/base.py b/disco/types/base.py index 0cf0646..d12de6e 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -40,11 +40,12 @@ class ConversionError(Exception): class Field(object): - def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, **kwargs): + def __init__(self, value_type, alias=None, default=None, create=True, ignore_dump=None, cast=None, **kwargs): # TODO: fix default bullshit self.src_name = alias self.dst_name = None self.ignore_dump = ignore_dump or [] + self.cast = cast self.metadata = kwargs if default is not None: @@ -101,6 +102,8 @@ class Field(object): elif isinstance(value, Model): return value.to_dict(ignore=(inst.ignore_dump if inst else [])) else: + if inst and inst.cast: + return inst.cast(value) return value def __call__(self, raw, client): diff --git a/disco/types/channel.py b/disco/types/channel.py index 4a26c05..664fd94 100644 --- a/disco/types/channel.py +++ b/disco/types/channel.py @@ -48,8 +48,8 @@ class PermissionOverwrite(ChannelSubType): """ id = Field(snowflake) type = Field(enum(PermissionOverwriteType)) - allow = Field(PermissionValue) - deny = Field(PermissionValue) + allow = Field(PermissionValue, cast=int) + deny = Field(PermissionValue, cast=int) channel_id = Field(snowflake) @@ -67,6 +67,13 @@ class PermissionOverwrite(ChannelSubType): channel_id=channel.id ).save() + @property + def compiled(self): + value = PermissionValue() + value -= self.deny + value += self.allow + return value + def save(self): self.client.api.channels_permissions_modify(self.channel_id, self.id, @@ -117,7 +124,10 @@ class Channel(SlottedModel, Permissible): def __init__(self, *args, **kwargs): super(Channel, self).__init__(*args, **kwargs) + self.after_load() + def after_load(self): + # TODO: hackfix self.attach(six.itervalues(self.overwrites), {'channel_id': self.id, 'channel': self}) def __str__(self): diff --git a/disco/types/permissions.py b/disco/types/permissions.py index ff43145..6e4d9f3 100644 --- a/disco/types/permissions.py +++ b/disco/types/permissions.py @@ -90,9 +90,12 @@ class PermissionValue(object): else: self.value &= ~Permissions[name].value + def __int__(self): + return self.value + def to_dict(self): return { - k: getattr(self, k) for k in Permissions.attrs + k: getattr(self, k) for k in Permissions.keys_ } @classmethod