From 83fe98c20d2bedbc10fdd10d278e171916887ba9 Mon Sep 17 00:00:00 2001 From: Nadir Chowdhury Date: Wed, 7 Apr 2021 12:55:55 +0100 Subject: [PATCH] Add typing for flags --- discord/flags.py | 74 +++++++++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/discord/flags.py b/discord/flags.py index d7ceb9076..af75132d8 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, overload + from .enums import UserFlags __all__ = ( @@ -32,17 +36,28 @@ __all__ = ( 'MemberCacheFlags', ) -class flag_value: - def __init__(self, func): +FV = TypeVar('FV', bound='flag_value') +BF = TypeVar('BF', bound='BaseFlags') + +class flag_value(Generic[BF]): + def __init__(self, func: Callable[[Any], int]): self.flag = func(None) self.__doc__ = func.__doc__ - def __get__(self, instance, owner): + @overload + def __get__(self: FV, instance: None, owner: Type[BF]) -> FV: + ... + + @overload + def __get__(self, instance: BF, owner: Type[BF]) -> bool: + ... + + def __get__(self, instance: Optional[BF], owner: Type[BF]) -> Any: if instance is None: return self return instance._has_flag(self.flag) - def __set__(self, instance, value): + def __set__(self, instance: BF, value: bool) -> None: instance._set_flag(self.flag, value) def __repr__(self): @@ -51,8 +66,8 @@ class flag_value: class alias_flag_value(flag_value): pass -def fill_with_flags(*, inverted=False): - def decorator(cls): +def fill_with_flags(*, inverted: bool = False): + def decorator(cls: Type[BF]): cls.VALID_FLAGS = { name: value.flag for name, value in cls.__dict__.items() @@ -70,9 +85,14 @@ def fill_with_flags(*, inverted=False): # n.b. flags must inherit from this and use the decorator above class BaseFlags: + VALID_FLAGS: ClassVar[Dict[str, int]] + DEFAULT_VALUE: ClassVar[int] + + value: int + __slots__ = ('value',) - def __init__(self, **kwargs): + def __init__(self, **kwargs: bool): self.value = self.DEFAULT_VALUE for key, value in kwargs.items(): if key not in self.VALID_FLAGS: @@ -85,19 +105,19 @@ class BaseFlags: self.value = value return self - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, self.__class__) and self.value == other.value - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.value) - def __repr__(self): + def __repr__(self) -> str: return f'<{self.__class__.__name__} value={self.value}>' - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[str, bool]]: for name, value in self.__class__.__dict__.items(): if isinstance(value, alias_flag_value): continue @@ -105,10 +125,10 @@ class BaseFlags: if isinstance(value, flag_value): yield (name, self._has_flag(value.flag)) - def _has_flag(self, o): + def _has_flag(self, o: int) -> bool: return (self.value & o) == o - def _set_flag(self, o, toggle): + def _set_flag(self, o: int, toggle: bool) -> None: if toggle is True: self.value |= o elif toggle is False: @@ -150,6 +170,7 @@ class SystemChannelFlags(BaseFlags): representing the currently available flags. You should query flags via the properties rather than using this raw value. """ + __slots__ = () # For some reason the flags for system channels are "inverted" @@ -157,10 +178,10 @@ class SystemChannelFlags(BaseFlags): # Since this is counter-intuitive from an API perspective and annoying # these will be inverted automatically - def _has_flag(self, o): + def _has_flag(self, o: int) -> bool: return (self.value & o) != o - def _set_flag(self, o, toggle): + def _set_flag(self, o: int, toggle: bool) -> None: if toggle is True: self.value &= ~o elif toggle is False: @@ -210,6 +231,7 @@ class MessageFlags(BaseFlags): representing the currently available flags. You should query flags via the properties rather than using this raw value. """ + __slots__ = () @flag_value @@ -346,7 +368,7 @@ class PublicUserFlags(BaseFlags): """ return UserFlags.verified_bot_developer.value - def all(self): + def all(self) -> List[UserFlags]: """List[:class:`UserFlags`]: Returns all public flags the user has.""" return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] @@ -393,7 +415,7 @@ class Intents(BaseFlags): __slots__ = () - def __init__(self, **kwargs): + def __init__(self, **kwargs: bool): self.value = self.DEFAULT_VALUE for key, value in kwargs.items(): if key not in self.VALID_FLAGS: @@ -401,7 +423,7 @@ class Intents(BaseFlags): setattr(self, key, value) @classmethod - def all(cls): + def all(cls: Type[Intents]) -> Intents: """A factory method that creates a :class:`Intents` with everything enabled.""" bits = max(cls.VALID_FLAGS.values()).bit_length() value = (1 << bits) - 1 @@ -410,14 +432,14 @@ class Intents(BaseFlags): return self @classmethod - def none(cls): + def none(cls: Type[Intents]) -> Intents: """A factory method that creates a :class:`Intents` with everything disabled.""" self = cls.__new__(cls) self.value = self.DEFAULT_VALUE return self @classmethod - def default(cls): + def default(cls: Type[Intents]) -> Intents: """A factory method that creates a :class:`Intents` with everything enabled except :attr:`presences` and :attr:`members`. """ @@ -825,7 +847,7 @@ class MemberCacheFlags(BaseFlags): __slots__ = () - def __init__(self, **kwargs): + def __init__(self, **kwargs: bool): bits = max(self.VALID_FLAGS.values()).bit_length() self.value = (1 << bits) - 1 for key, value in kwargs.items(): @@ -834,7 +856,7 @@ class MemberCacheFlags(BaseFlags): setattr(self, key, value) @classmethod - def all(cls): + def all(cls: Type[MemberCacheFlags]) -> MemberCacheFlags: """A factory method that creates a :class:`MemberCacheFlags` with everything enabled.""" bits = max(cls.VALID_FLAGS.values()).bit_length() value = (1 << bits) - 1 @@ -843,7 +865,7 @@ class MemberCacheFlags(BaseFlags): return self @classmethod - def none(cls): + def none(cls: Type[MemberCacheFlags]) -> MemberCacheFlags: """A factory method that creates a :class:`MemberCacheFlags` with everything disabled.""" self = cls.__new__(cls) self.value = self.DEFAULT_VALUE @@ -886,7 +908,7 @@ class MemberCacheFlags(BaseFlags): return 4 @classmethod - def from_intents(cls, intents): + def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags: """A factory method that creates a :class:`MemberCacheFlags` based on the currently selected :class:`Intents`. @@ -914,7 +936,7 @@ class MemberCacheFlags(BaseFlags): return self - def _verify_intents(self, intents): + def _verify_intents(self, intents: Intents): if self.online and not intents.presences: raise ValueError('MemberCacheFlags.online requires Intents.presences enabled')