From f7687e0a684023ab91e5bb97d1f1cadbaf8e0413 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Fri, 20 Dec 2019 21:18:12 -0500 Subject: [PATCH] Clean up flag code significantly. This also fixes the False setting bug. --- discord/flags.py | 168 +++++++++++++++++++++-------------------------- 1 file changed, 74 insertions(+), 94 deletions(-) diff --git a/discord/flags.py b/discord/flags.py index 1c3ce536a..dc89c46d1 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -29,7 +29,7 @@ __all__ = ( 'MessageFlags', ) -class _flag_descriptor: +class flag_value: def __init__(self, func): self.flag = func(None) self.__doc__ = func.__doc__ @@ -40,19 +40,70 @@ class _flag_descriptor: def __set__(self, instance, value): instance._set_flag(self.flag, value) -def fill_with_flags(cls): - cls.VALID_FLAGS = { - name: value.flag - for name, value in cls.__dict__.items() - if isinstance(value, _flag_descriptor) - } +def fill_with_flags(*, inverted=False): + def decorator(cls): + cls.VALID_FLAGS = { + name: value.flag + for name, value in cls.__dict__.items() + if isinstance(value, flag_value) + } + + if inverted: + max_bits = max(cls.VALID_FLAGS.values()).bit_length() + cls.DEFAULT_VALUE = -1 + (2 ** max_bits) + else: + cls.DEFAULT_VALUE = 0 + + return cls + return decorator + +# n.b. flags must inherit from this and use the decorator above +class BaseFlags: + __slots__ = ('value',) + + def __init__(self, **kwargs): + self.value = self.DEFAULT_VALUE + for key, value in kwargs.items(): + if key not in self.VALID_FLAGS: + raise TypeError('%r is not a valid flag name.' % key) + setattr(self, key, value) + + @classmethod + def _from_value(cls, value): + self = cls.__new__(cls) + self.value = value + return self + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.value == other.value + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.value) + + def __repr__(self): + return '<%s value=%s>' % (self.__class__.__name__, self.value) - max_bits = max(cls.VALID_FLAGS.values()).bit_length() - cls.ALL_OFF_VALUE = -1 + (2 ** max_bits) - return cls + def __iter__(self): + for name, value in self.__class__.__dict__.items(): + if isinstance(value, flag_value): + yield (name, self._has_flag(value.flag)) + + def _has_flag(self, o): + return (self.value & o) == o + + def _set_flag(self, o, toggle): + if toggle is True: + self.value |= o + elif toggle is False: + self.value &= ~o + else: + raise TypeError('Value to set for %s must be a bool.' % self.__class__.__name__) -@fill_with_flags -class SystemChannelFlags: +@fill_with_flags(inverted=True) +class SystemChannelFlags(BaseFlags): r"""Wraps up a Discord system channel flag value. Similar to :class:`Permissions`\, the properties provided are two way. @@ -85,37 +136,7 @@ class SystemChannelFlags: representing the currently available flags. You should query flags via the properties rather than using this raw value. """ - __slots__ = ('value',) - - def __init__(self, **kwargs): - self.value = self.ALL_OFF_VALUE - for key, value in kwargs.items(): - if key not in self.VALID_FLAGS: - raise TypeError('%r is not a valid flag name.' % key) - setattr(self, key, value) - - @classmethod - def _from_value(cls, value): - self = cls.__new__(cls) - self.value = value - return self - - def __eq__(self, other): - return isinstance(other, SystemChannelFlags) and self.value == other.value - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(self.value) - - def __repr__(self): - return '' % self.value - - def __iter__(self): - for name, value in self.__class__.__dict__.items(): - if isinstance(value, _flag_descriptor): - yield (name, self._has_flag(value.flag)) + __slots__ = () # For some reason the flags for system channels are "inverted" # ergo, if they're set then it means "suppress" (off in the GUI toggle) @@ -133,19 +154,19 @@ class SystemChannelFlags: else: raise TypeError('Value to set for SystemChannelFlags must be a bool.') - @_flag_descriptor + @flag_value def join_notifications(self): """:class:`bool`: Returns ``True`` if the system channel is used for member join notifications.""" return 1 - @_flag_descriptor + @flag_value def premium_subscriptions(self): """:class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications.""" return 2 -@fill_with_flags -class MessageFlags: +@fill_with_flags() +class MessageFlags(BaseFlags): r"""Wraps up a Discord Message flag value. See :class:`SystemChannelFlags`. @@ -173,65 +194,24 @@ class MessageFlags: representing the currently available flags. You should query flags via the properties rather than using this raw value. """ - __slots__ = ('value',) + __slots__ = () - def __init__(self, **kwargs): - self.value = 0 - for key, value in kwargs.items(): - if key not in self.VALID_FLAGS: - raise TypeError('%r is not a valid flag name.' % key) - setattr(self, key, value) - - @classmethod - def _from_value(cls, value): - self = cls.__new__(cls) - self.value = value - return self - - def __eq__(self, other): - return isinstance(other, MessageFlags) and self.value == other.value - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(self.value) - - def __repr__(self): - return '' % self.value - - def __iter__(self): - for name, value in self.__class__.__dict__.items(): - if isinstance(value, _flag_descriptor): - yield (name, self._has_flag(value.flag)) - - def _has_flag(self, o): - return (self.value & o) == o - - def _set_flag(self, o, toggle): - if toggle is True: - self.value |= o - elif toggle is False: - self.value &= o - else: - raise TypeError('Value to set for MessageFlags must be a bool.') - - @_flag_descriptor + @flag_value def crossposted(self): """:class:`bool`: Returns ``True`` if the message is the original crossposted message.""" return 1 - @_flag_descriptor + @flag_value def is_crossposted(self): """:class:`bool`: Returns ``True`` if the message was crossposted from another channel.""" return 2 - @_flag_descriptor + @flag_value def suppress_embeds(self): """:class:`bool`: Returns ``True`` if the message's embeds have been suppressed.""" return 4 - - @_flag_descriptor + + @flag_value def source_message_deleted(self): """:class:`bool`: Returns ``True`` if the source message for this crosspost has been deleted.""" return 8