Browse Source

Clean up flag code significantly.

This also fixes the False setting bug.
pull/2484/head
Rapptz 5 years ago
parent
commit
f7687e0a68
  1. 168
      discord/flags.py

168
discord/flags.py

@ -29,7 +29,7 @@ __all__ = (
'MessageFlags', 'MessageFlags',
) )
class _flag_descriptor: class flag_value:
def __init__(self, func): def __init__(self, func):
self.flag = func(None) self.flag = func(None)
self.__doc__ = func.__doc__ self.__doc__ = func.__doc__
@ -40,19 +40,70 @@ class _flag_descriptor:
def __set__(self, instance, value): def __set__(self, instance, value):
instance._set_flag(self.flag, value) instance._set_flag(self.flag, value)
def fill_with_flags(cls): def fill_with_flags(*, inverted=False):
cls.VALID_FLAGS = { def decorator(cls):
name: value.flag cls.VALID_FLAGS = {
for name, value in cls.__dict__.items() name: value.flag
if isinstance(value, _flag_descriptor) for name, value in cls.__dict__.items()
} if isinstance(value, flag_value)
}
if inverted:
max_bits = max(cls.VALID_FLAGS.values()).bit_length()
cls.DEFAULT_VALUE = -1 + (2 ** max_bits)
else:
cls.DEFAULT_VALUE = 0
return cls
return decorator
# n.b. flags must inherit from this and use the decorator above
class BaseFlags:
__slots__ = ('value',)
def __init__(self, **kwargs):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError('%r is not a valid flag name.' % key)
setattr(self, key, value)
@classmethod
def _from_value(cls, value):
self = cls.__new__(cls)
self.value = value
return self
def __eq__(self, other):
return isinstance(other, self.__class__) and self.value == other.value
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.value)
def __repr__(self):
return '<%s value=%s>' % (self.__class__.__name__, self.value)
max_bits = max(cls.VALID_FLAGS.values()).bit_length() def __iter__(self):
cls.ALL_OFF_VALUE = -1 + (2 ** max_bits) for name, value in self.__class__.__dict__.items():
return cls if isinstance(value, flag_value):
yield (name, self._has_flag(value.flag))
def _has_flag(self, o):
return (self.value & o) == o
def _set_flag(self, o, toggle):
if toggle is True:
self.value |= o
elif toggle is False:
self.value &= ~o
else:
raise TypeError('Value to set for %s must be a bool.' % self.__class__.__name__)
@fill_with_flags @fill_with_flags(inverted=True)
class SystemChannelFlags: class SystemChannelFlags(BaseFlags):
r"""Wraps up a Discord system channel flag value. r"""Wraps up a Discord system channel flag value.
Similar to :class:`Permissions`\, the properties provided are two way. Similar to :class:`Permissions`\, the properties provided are two way.
@ -85,37 +136,7 @@ class SystemChannelFlags:
representing the currently available flags. You should query representing the currently available flags. You should query
flags via the properties rather than using this raw value. flags via the properties rather than using this raw value.
""" """
__slots__ = ('value',) __slots__ = ()
def __init__(self, **kwargs):
self.value = self.ALL_OFF_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError('%r is not a valid flag name.' % key)
setattr(self, key, value)
@classmethod
def _from_value(cls, value):
self = cls.__new__(cls)
self.value = value
return self
def __eq__(self, other):
return isinstance(other, SystemChannelFlags) and self.value == other.value
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.value)
def __repr__(self):
return '<SystemChannelFlags value=%s>' % self.value
def __iter__(self):
for name, value in self.__class__.__dict__.items():
if isinstance(value, _flag_descriptor):
yield (name, self._has_flag(value.flag))
# For some reason the flags for system channels are "inverted" # For some reason the flags for system channels are "inverted"
# ergo, if they're set then it means "suppress" (off in the GUI toggle) # ergo, if they're set then it means "suppress" (off in the GUI toggle)
@ -133,19 +154,19 @@ class SystemChannelFlags:
else: else:
raise TypeError('Value to set for SystemChannelFlags must be a bool.') raise TypeError('Value to set for SystemChannelFlags must be a bool.')
@_flag_descriptor @flag_value
def join_notifications(self): def join_notifications(self):
""":class:`bool`: Returns ``True`` if the system channel is used for member join notifications.""" """:class:`bool`: Returns ``True`` if the system channel is used for member join notifications."""
return 1 return 1
@_flag_descriptor @flag_value
def premium_subscriptions(self): def premium_subscriptions(self):
""":class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications.""" """:class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications."""
return 2 return 2
@fill_with_flags @fill_with_flags()
class MessageFlags: class MessageFlags(BaseFlags):
r"""Wraps up a Discord Message flag value. r"""Wraps up a Discord Message flag value.
See :class:`SystemChannelFlags`. See :class:`SystemChannelFlags`.
@ -173,65 +194,24 @@ class MessageFlags:
representing the currently available flags. You should query representing the currently available flags. You should query
flags via the properties rather than using this raw value. flags via the properties rather than using this raw value.
""" """
__slots__ = ('value',) __slots__ = ()
def __init__(self, **kwargs): @flag_value
self.value = 0
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError('%r is not a valid flag name.' % key)
setattr(self, key, value)
@classmethod
def _from_value(cls, value):
self = cls.__new__(cls)
self.value = value
return self
def __eq__(self, other):
return isinstance(other, MessageFlags) and self.value == other.value
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.value)
def __repr__(self):
return '<MessageFlags value=%s>' % self.value
def __iter__(self):
for name, value in self.__class__.__dict__.items():
if isinstance(value, _flag_descriptor):
yield (name, self._has_flag(value.flag))
def _has_flag(self, o):
return (self.value & o) == o
def _set_flag(self, o, toggle):
if toggle is True:
self.value |= o
elif toggle is False:
self.value &= o
else:
raise TypeError('Value to set for MessageFlags must be a bool.')
@_flag_descriptor
def crossposted(self): def crossposted(self):
""":class:`bool`: Returns ``True`` if the message is the original crossposted message.""" """:class:`bool`: Returns ``True`` if the message is the original crossposted message."""
return 1 return 1
@_flag_descriptor @flag_value
def is_crossposted(self): def is_crossposted(self):
""":class:`bool`: Returns ``True`` if the message was crossposted from another channel.""" """:class:`bool`: Returns ``True`` if the message was crossposted from another channel."""
return 2 return 2
@_flag_descriptor @flag_value
def suppress_embeds(self): def suppress_embeds(self):
""":class:`bool`: Returns ``True`` if the message's embeds have been suppressed.""" """:class:`bool`: Returns ``True`` if the message's embeds have been suppressed."""
return 4 return 4
@_flag_descriptor @flag_value
def source_message_deleted(self): def source_message_deleted(self):
""":class:`bool`: Returns ``True`` if the source message for this crosspost has been deleted.""" """:class:`bool`: Returns ``True`` if the source message for this crosspost has been deleted."""
return 8 return 8

Loading…
Cancel
Save