From 7ee15e1d684ac697e9e2228ffba2bd2e7b3c618f Mon Sep 17 00:00:00 2001 From: Lilly Rose Berner Date: Fri, 29 Apr 2022 12:07:22 +0200 Subject: [PATCH] Use typing.Literal for channel and component type annotation --- discord/channel.py | 23 +++++++++++++---------- discord/components.py | 8 ++++---- discord/threads.py | 8 +++++--- discord/ui/button.py | 4 ++-- discord/ui/select.py | 4 ++-- discord/ui/text_input.py | 4 ++-- 6 files changed, 28 insertions(+), 23 deletions(-) diff --git a/discord/channel.py b/discord/channel.py index 236398769..8542f14cb 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -31,6 +31,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, Optional, TYPE_CHECKING, @@ -165,7 +166,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): self._state: ConnectionState = state self.id: int = int(data['id']) - self._type: int = data['type'] + self._type: Literal[0, 5] = data['type'] self._update(guild, data) def __repr__(self) -> str: @@ -190,7 +191,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): # Does this need coercion into `int`? No idea yet. self.slowmode_delay: int = data.get('rate_limit_per_user', 0) self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440) - self._type: int = data.get('type', self._type) + self._type: Literal[0, 5] = data.get('type', self._type) self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self._fill_overwrites(data) @@ -198,9 +199,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): return self @property - def type(self) -> ChannelType: + def type(self) -> Literal[ChannelType.text, ChannelType.news]: """:class:`ChannelType`: The channel's Discord type.""" - return try_enum(ChannelType, self._type) + if self.type == 0: + return ChannelType.text + return ChannelType.news @property def _sorting_bucket(self) -> int: @@ -1036,7 +1039,7 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel): return self @property - def type(self) -> ChannelType: + def type(self) -> Literal[ChannelType.voice]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.voice @@ -1505,7 +1508,7 @@ class StageChannel(VocalGuildChannel): return [member for member in self.members if self.permissions_for(member) >= required_permissions] @property - def type(self) -> ChannelType: + def type(self) -> Literal[ChannelType.stage_voice]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.stage_voice @@ -1749,7 +1752,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): return ChannelType.category.value @property - def type(self) -> ChannelType: + def type(self) -> Literal[ChannelType.category]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.category @@ -2016,7 +2019,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): self._fill_overwrites(data) @property - def type(self) -> ChannelType: + def type(self) -> Literal[ChannelType.forum]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.forum @@ -2330,7 +2333,7 @@ class DMChannel(discord.abc.Messageable, Hashable): return self @property - def type(self) -> ChannelType: + def type(self) -> Literal[ChannelType.private]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.private @@ -2484,7 +2487,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): return f'' @property - def type(self) -> ChannelType: + def type(self) -> Literal[ChannelType.group]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.group diff --git a/discord/components.py b/discord/components.py index 853e0b401..fbb87aba1 100644 --- a/discord/components.py +++ b/discord/components.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union from .enums import try_enum, ComponentType, ButtonStyle, TextStyle from .utils import get_slots, MISSING from .partial_emoji import PartialEmoji, _EmojiTag @@ -119,7 +119,7 @@ class ActionRow(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) + self.type: Literal[ComponentType.action_row] = ComponentType.action_row self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])] def to_dict(self) -> ActionRowPayload: @@ -170,7 +170,7 @@ class Button(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ButtonComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) + self.type: Literal[ComponentType.button] = ComponentType.button self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) self.custom_id: Optional[str] = data.get('custom_id') self.url: Optional[str] = data.get('url') @@ -244,7 +244,7 @@ class SelectMenu(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: SelectMenuPayload): - self.type = ComponentType.select + self.type: Literal[ComponentType.select] = ComponentType.select self.custom_id: str = data['custom_id'] self.placeholder: Optional[str] = data.get('placeholder') self.min_values: int = data.get('min_values', 1) diff --git a/discord/threads.py b/discord/threads.py index 6451bca21..8ec0adcbc 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING +from typing import Callable, Dict, Iterable, List, Literal, Optional, Union, TYPE_CHECKING from datetime import datetime from .mixins import Hashable @@ -58,6 +58,8 @@ if TYPE_CHECKING: from .permissions import Permissions from .state import ConnectionState + ThreadChannelType = Literal[ChannelType.news_thread, ChannelType.public_thread, ChannelType.private_thread] + class Thread(Messageable, Hashable): """Represents a Discord thread. @@ -172,7 +174,7 @@ class Thread(Messageable, Hashable): self.parent_id: int = int(data['parent_id']) self.owner_id: int = int(data['owner_id']) self.name: str = data['name'] - self._type: ChannelType = try_enum(ChannelType, data['type']) + self._type: ThreadChannelType = try_enum(ChannelType, data['type']) # type: ignore self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id') self.slowmode_delay: int = data.get('rate_limit_per_user', 0) self.message_count: int = data['message_count'] @@ -211,7 +213,7 @@ class Thread(Messageable, Hashable): pass @property - def type(self) -> ChannelType: + def type(self) -> ThreadChannelType: """:class:`ChannelType`: The channel's Discord type.""" return self._type diff --git a/discord/ui/button.py b/discord/ui/button.py index 8de338da4..7622d9756 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Callable, Optional, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import Callable, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union import inspect import os @@ -213,7 +213,7 @@ class Button(Item[V]): ) @property - def type(self) -> ComponentType: + def type(self) -> Literal[ComponentType.button]: return self._underlying.type def to_component_dict(self) -> ButtonComponentPayload: diff --git a/discord/ui/select.py b/discord/ui/select.py index 170f35093..50d56f845 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import List, Optional, TYPE_CHECKING, Tuple, TypeVar, Callable, Union +from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Callable, Union import inspect import os @@ -288,7 +288,7 @@ class Select(Item[V]): ) @property - def type(self) -> ComponentType: + def type(self) -> Literal[ComponentType.select]: return self._underlying.type def is_dispatchable(self) -> bool: diff --git a/discord/ui/text_input.py b/discord/ui/text_input.py index 6a798eea4..3952affef 100644 --- a/discord/ui/text_input.py +++ b/discord/ui/text_input.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import os -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Literal, Optional, Tuple, TypeVar from ..components import TextInput as TextInputComponent from ..enums import ComponentType, TextStyle @@ -231,7 +231,7 @@ class TextInput(Item[V]): ) @property - def type(self) -> ComponentType: + def type(self) -> Literal[ComponentType.text_input]: return ComponentType.text_input def is_dispatchable(self) -> bool: