Browse Source

Use typing.Literal for channel and component type annotation

pull/7956/head
Lilly Rose Berner 3 years ago
committed by GitHub
parent
commit
7ee15e1d68
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 23
      discord/channel.py
  2. 8
      discord/components.py
  3. 8
      discord/threads.py
  4. 4
      discord/ui/button.py
  5. 4
      discord/ui/select.py
  6. 4
      discord/ui/text_input.py

23
discord/channel.py

@ -31,6 +31,7 @@ from typing import (
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
TYPE_CHECKING,
@ -165,7 +166,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self._type: int = data['type']
self._type: Literal[0, 5] = data['type']
self._update(guild, data)
def __repr__(self) -> str:
@ -190,7 +191,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
# Does this need coercion into `int`? No idea yet.
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440)
self._type: int = data.get('type', self._type)
self._type: Literal[0, 5] = data.get('type', self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data)
@ -198,9 +199,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return self
@property
def type(self) -> ChannelType:
def type(self) -> Literal[ChannelType.text, ChannelType.news]:
""":class:`ChannelType`: The channel's Discord type."""
return try_enum(ChannelType, self._type)
if self.type == 0:
return ChannelType.text
return ChannelType.news
@property
def _sorting_bucket(self) -> int:
@ -1036,7 +1039,7 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
return self
@property
def type(self) -> ChannelType:
def type(self) -> Literal[ChannelType.voice]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.voice
@ -1505,7 +1508,7 @@ class StageChannel(VocalGuildChannel):
return [member for member in self.members if self.permissions_for(member) >= required_permissions]
@property
def type(self) -> ChannelType:
def type(self) -> Literal[ChannelType.stage_voice]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.stage_voice
@ -1749,7 +1752,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return ChannelType.category.value
@property
def type(self) -> ChannelType:
def type(self) -> Literal[ChannelType.category]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.category
@ -2016,7 +2019,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
self._fill_overwrites(data)
@property
def type(self) -> ChannelType:
def type(self) -> Literal[ChannelType.forum]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.forum
@ -2330,7 +2333,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
return self
@property
def type(self) -> ChannelType:
def type(self) -> Literal[ChannelType.private]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.private
@ -2484,7 +2487,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
return f'<GroupChannel id={self.id} name={self.name!r}>'
@property
def type(self) -> ChannelType:
def type(self) -> Literal[ChannelType.group]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.group

8
discord/components.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union
from .enums import try_enum, ComponentType, ButtonStyle, TextStyle
from .utils import get_slots, MISSING
from .partial_emoji import PartialEmoji, _EmojiTag
@ -119,7 +119,7 @@ class ActionRow(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.type: Literal[ComponentType.action_row] = ComponentType.action_row
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
def to_dict(self) -> ActionRowPayload:
@ -170,7 +170,7 @@ class Button(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.type: Literal[ComponentType.button] = ComponentType.button
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url')
@ -244,7 +244,7 @@ class SelectMenu(Component):
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload):
self.type = ComponentType.select
self.type: Literal[ComponentType.select] = ComponentType.select
self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1)

8
discord/threads.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
from typing import Callable, Dict, Iterable, List, Literal, Optional, Union, TYPE_CHECKING
from datetime import datetime
from .mixins import Hashable
@ -58,6 +58,8 @@ if TYPE_CHECKING:
from .permissions import Permissions
from .state import ConnectionState
ThreadChannelType = Literal[ChannelType.news_thread, ChannelType.public_thread, ChannelType.private_thread]
class Thread(Messageable, Hashable):
"""Represents a Discord thread.
@ -172,7 +174,7 @@ class Thread(Messageable, Hashable):
self.parent_id: int = int(data['parent_id'])
self.owner_id: int = int(data['owner_id'])
self.name: str = data['name']
self._type: ChannelType = try_enum(ChannelType, data['type'])
self._type: ThreadChannelType = try_enum(ChannelType, data['type']) # type: ignore
self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.message_count: int = data['message_count']
@ -211,7 +213,7 @@ class Thread(Messageable, Hashable):
pass
@property
def type(self) -> ChannelType:
def type(self) -> ThreadChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return self._type

4
discord/ui/button.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Callable, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
from typing import Callable, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
import inspect
import os
@ -213,7 +213,7 @@ class Button(Item[V]):
)
@property
def type(self) -> ComponentType:
def type(self) -> Literal[ComponentType.button]:
return self._underlying.type
def to_component_dict(self) -> ButtonComponentPayload:

4
discord/ui/select.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Optional, TYPE_CHECKING, Tuple, TypeVar, Callable, Union
from typing import List, Literal, Optional, TYPE_CHECKING, Tuple, TypeVar, Callable, Union
import inspect
import os
@ -288,7 +288,7 @@ class Select(Item[V]):
)
@property
def type(self) -> ComponentType:
def type(self) -> Literal[ComponentType.select]:
return self._underlying.type
def is_dispatchable(self) -> bool:

4
discord/ui/text_input.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Literal, Optional, Tuple, TypeVar
from ..components import TextInput as TextInputComponent
from ..enums import ComponentType, TextStyle
@ -231,7 +231,7 @@ class TextInput(Item[V]):
)
@property
def type(self) -> ComponentType:
def type(self) -> Literal[ComponentType.text_input]:
return ComponentType.text_input
def is_dispatchable(self) -> bool:

Loading…
Cancel
Save