Browse Source

Add typing for flags

pull/6669/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
83fe98c20d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 74
      discord/flags.py

74
discord/flags.py

@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, overload
from .enums import UserFlags from .enums import UserFlags
__all__ = ( __all__ = (
@ -32,17 +36,28 @@ __all__ = (
'MemberCacheFlags', 'MemberCacheFlags',
) )
class flag_value: FV = TypeVar('FV', bound='flag_value')
def __init__(self, func): BF = TypeVar('BF', bound='BaseFlags')
class flag_value(Generic[BF]):
def __init__(self, func: Callable[[Any], int]):
self.flag = func(None) self.flag = func(None)
self.__doc__ = func.__doc__ self.__doc__ = func.__doc__
def __get__(self, instance, owner): @overload
def __get__(self: FV, instance: None, owner: Type[BF]) -> FV:
...
@overload
def __get__(self, instance: BF, owner: Type[BF]) -> bool:
...
def __get__(self, instance: Optional[BF], owner: Type[BF]) -> Any:
if instance is None: if instance is None:
return self return self
return instance._has_flag(self.flag) return instance._has_flag(self.flag)
def __set__(self, instance, value): def __set__(self, instance: BF, value: bool) -> None:
instance._set_flag(self.flag, value) instance._set_flag(self.flag, value)
def __repr__(self): def __repr__(self):
@ -51,8 +66,8 @@ class flag_value:
class alias_flag_value(flag_value): class alias_flag_value(flag_value):
pass pass
def fill_with_flags(*, inverted=False): def fill_with_flags(*, inverted: bool = False):
def decorator(cls): def decorator(cls: Type[BF]):
cls.VALID_FLAGS = { cls.VALID_FLAGS = {
name: value.flag name: value.flag
for name, value in cls.__dict__.items() for name, value in cls.__dict__.items()
@ -70,9 +85,14 @@ def fill_with_flags(*, inverted=False):
# n.b. flags must inherit from this and use the decorator above # n.b. flags must inherit from this and use the decorator above
class BaseFlags: class BaseFlags:
VALID_FLAGS: ClassVar[Dict[str, int]]
DEFAULT_VALUE: ClassVar[int]
value: int
__slots__ = ('value',) __slots__ = ('value',)
def __init__(self, **kwargs): def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE self.value = self.DEFAULT_VALUE
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
@ -85,19 +105,19 @@ class BaseFlags:
self.value = value self.value = value
return self return self
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.value == other.value return isinstance(other, self.__class__) and self.value == other.value
def __ne__(self, other): def __ne__(self, other: Any) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self): def __hash__(self) -> int:
return hash(self.value) return hash(self.value)
def __repr__(self): def __repr__(self) -> str:
return f'<{self.__class__.__name__} value={self.value}>' return f'<{self.__class__.__name__} value={self.value}>'
def __iter__(self): def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name, value in self.__class__.__dict__.items(): for name, value in self.__class__.__dict__.items():
if isinstance(value, alias_flag_value): if isinstance(value, alias_flag_value):
continue continue
@ -105,10 +125,10 @@ class BaseFlags:
if isinstance(value, flag_value): if isinstance(value, flag_value):
yield (name, self._has_flag(value.flag)) yield (name, self._has_flag(value.flag))
def _has_flag(self, o): def _has_flag(self, o: int) -> bool:
return (self.value & o) == o return (self.value & o) == o
def _set_flag(self, o, toggle): def _set_flag(self, o: int, toggle: bool) -> None:
if toggle is True: if toggle is True:
self.value |= o self.value |= o
elif toggle is False: elif toggle is False:
@ -150,6 +170,7 @@ class SystemChannelFlags(BaseFlags):
representing the currently available flags. You should query representing the currently available flags. You should query
flags via the properties rather than using this raw value. flags via the properties rather than using this raw value.
""" """
__slots__ = () __slots__ = ()
# For some reason the flags for system channels are "inverted" # For some reason the flags for system channels are "inverted"
@ -157,10 +178,10 @@ class SystemChannelFlags(BaseFlags):
# Since this is counter-intuitive from an API perspective and annoying # Since this is counter-intuitive from an API perspective and annoying
# these will be inverted automatically # these will be inverted automatically
def _has_flag(self, o): def _has_flag(self, o: int) -> bool:
return (self.value & o) != o return (self.value & o) != o
def _set_flag(self, o, toggle): def _set_flag(self, o: int, toggle: bool) -> None:
if toggle is True: if toggle is True:
self.value &= ~o self.value &= ~o
elif toggle is False: elif toggle is False:
@ -210,6 +231,7 @@ class MessageFlags(BaseFlags):
representing the currently available flags. You should query representing the currently available flags. You should query
flags via the properties rather than using this raw value. flags via the properties rather than using this raw value.
""" """
__slots__ = () __slots__ = ()
@flag_value @flag_value
@ -346,7 +368,7 @@ class PublicUserFlags(BaseFlags):
""" """
return UserFlags.verified_bot_developer.value return UserFlags.verified_bot_developer.value
def all(self): def all(self) -> List[UserFlags]:
"""List[:class:`UserFlags`]: Returns all public flags the user has.""" """List[:class:`UserFlags`]: Returns all public flags the user has."""
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]
@ -393,7 +415,7 @@ class Intents(BaseFlags):
__slots__ = () __slots__ = ()
def __init__(self, **kwargs): def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE self.value = self.DEFAULT_VALUE
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
@ -401,7 +423,7 @@ class Intents(BaseFlags):
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
def all(cls): def all(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled.""" """A factory method that creates a :class:`Intents` with everything enabled."""
bits = max(cls.VALID_FLAGS.values()).bit_length() bits = max(cls.VALID_FLAGS.values()).bit_length()
value = (1 << bits) - 1 value = (1 << bits) - 1
@ -410,14 +432,14 @@ class Intents(BaseFlags):
return self return self
@classmethod @classmethod
def none(cls): def none(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything disabled.""" """A factory method that creates a :class:`Intents` with everything disabled."""
self = cls.__new__(cls) self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE self.value = self.DEFAULT_VALUE
return self return self
@classmethod @classmethod
def default(cls): def default(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled """A factory method that creates a :class:`Intents` with everything enabled
except :attr:`presences` and :attr:`members`. except :attr:`presences` and :attr:`members`.
""" """
@ -825,7 +847,7 @@ class MemberCacheFlags(BaseFlags):
__slots__ = () __slots__ = ()
def __init__(self, **kwargs): def __init__(self, **kwargs: bool):
bits = max(self.VALID_FLAGS.values()).bit_length() bits = max(self.VALID_FLAGS.values()).bit_length()
self.value = (1 << bits) - 1 self.value = (1 << bits) - 1
for key, value in kwargs.items(): for key, value in kwargs.items():
@ -834,7 +856,7 @@ class MemberCacheFlags(BaseFlags):
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
def all(cls): def all(cls: Type[MemberCacheFlags]) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with everything enabled.""" """A factory method that creates a :class:`MemberCacheFlags` with everything enabled."""
bits = max(cls.VALID_FLAGS.values()).bit_length() bits = max(cls.VALID_FLAGS.values()).bit_length()
value = (1 << bits) - 1 value = (1 << bits) - 1
@ -843,7 +865,7 @@ class MemberCacheFlags(BaseFlags):
return self return self
@classmethod @classmethod
def none(cls): def none(cls: Type[MemberCacheFlags]) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` with everything disabled.""" """A factory method that creates a :class:`MemberCacheFlags` with everything disabled."""
self = cls.__new__(cls) self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE self.value = self.DEFAULT_VALUE
@ -886,7 +908,7 @@ class MemberCacheFlags(BaseFlags):
return 4 return 4
@classmethod @classmethod
def from_intents(cls, intents): def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` based on """A factory method that creates a :class:`MemberCacheFlags` based on
the currently selected :class:`Intents`. the currently selected :class:`Intents`.
@ -914,7 +936,7 @@ class MemberCacheFlags(BaseFlags):
return self return self
def _verify_intents(self, intents): def _verify_intents(self, intents: Intents):
if self.online and not intents.presences: if self.online and not intents.presences:
raise ValueError('MemberCacheFlags.online requires Intents.presences enabled') raise ValueError('MemberCacheFlags.online requires Intents.presences enabled')

Loading…
Cancel
Save