From e9e2d8cb1c05b2562f6a16a27b59432ff04cef68 Mon Sep 17 00:00:00 2001 From: Lilly Rose Berner Date: Mon, 16 May 2022 21:30:03 +0200 Subject: [PATCH] Improve component typing --- discord/components.py | 96 +++++++++++++++++++++-------------- discord/message.py | 19 ++++--- discord/partial_emoji.py | 17 ++++--- discord/types/components.py | 5 +- discord/types/interactions.py | 2 +- 5 files changed, 85 insertions(+), 54 deletions(-) diff --git a/discord/components.py b/discord/components.py index e97c415f6..ffd355120 100644 --- a/discord/components.py +++ b/discord/components.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Union +from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, InteractionType from .interactions import _wrapped_interaction @@ -39,13 +39,15 @@ if TYPE_CHECKING: ButtonComponent as ButtonComponentPayload, SelectMenu as SelectMenuPayload, SelectOption as SelectOptionPayload, - ActionRow as ActionRowPayload, TextInput as TextInputPayload, + ActionRowChildComponent as ActionRowChildComponentPayload, ) from .emoji import Emoji from .interactions import Interaction from .message import Message + ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput'] + __all__ = ( 'Component', @@ -68,23 +70,22 @@ class Component: - :class:`TextInput` .. versionadded:: 2.0 - - Attributes - ------------ - type: :class:`ComponentType` - The type of component. """ __slots__: Tuple[str, ...] = ('type', 'message') __repr_info__: ClassVar[Tuple[str, ...]] - type: ComponentType message: Message def __repr__(self) -> str: attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) return f'<{self.__class__.__name__} {attrs}>' + @property + def type(self) -> ComponentType: + """:class:`ComponentType`: The type of component.""" + raise NotImplementedError + @classmethod def _raw_construct(cls, **kwargs) -> Self: self = cls.__new__(cls) @@ -97,7 +98,7 @@ class Component: setattr(self, slot, value) return self - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ComponentPayload: raise NotImplementedError @@ -112,9 +113,7 @@ class ActionRow(Component): Attributes ------------ - type: :class:`ComponentType` - The type of component. - children: List[:class:`Component`] + children: List[Union[:class:`Button`, :class:`SelectMenu`, :class:`TextInput`]] The children components that this holds, if any. message: :class:`Message` The originating message. @@ -126,14 +125,18 @@ class ActionRow(Component): def __init__(self, data: ComponentPayload, message: Message): self.message = message - self.type: ComponentType = try_enum(ComponentType, data['type']) - self.children: List[Component] = [_component_factory(d, message) for d in data.get('components', [])] + self.children: List[ActionRowChildComponentType] = [] - def to_dict(self) -> ActionRowPayload: - return { - 'type': int(self.type), - 'components': [child.to_dict() for child in self.children], - } # type: ignore # Type checker does not understand these are the same + for component_data in data.get('components', []): + component = _component_factory(component_data) + + if component is not None: + self.children.append(component) + + @property + def type(self) -> Literal[ComponentType.action_row]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.action_row class Button(Component): @@ -175,7 +178,6 @@ class Button(Component): def __init__(self, data: ButtonComponentPayload, message: Message): self.message = message - self.type: ComponentType = try_enum(ComponentType, data['type']) self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) self.custom_id: Optional[str] = data.get('custom_id') self.url: Optional[str] = data.get('url') @@ -193,6 +195,11 @@ class Button(Component): 'custom_id': self.custom_id, } + @property + def type(self) -> Literal[ComponentType.button]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.button + async def click(self) -> Union[str, Interaction]: """|coro| @@ -268,7 +275,6 @@ class SelectMenu(Component): def __init__(self, data: SelectMenuPayload, message: Message): self.message = message - self.type = ComponentType.select self.custom_id: str = data['custom_id'] self.placeholder: Optional[str] = data.get('placeholder') self.min_values: int = data.get('min_values', 1) @@ -277,6 +283,11 @@ class SelectMenu(Component): self.disabled: bool = data.get('disabled', False) self.hash: str = data.get('hash', '') + @property + def type(self) -> Literal[ComponentType.select]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.select + def to_dict(self, options: Tuple[SelectOption]) -> dict: return { 'component_type': self.type.value, @@ -333,7 +344,7 @@ class SelectOption: description: Optional[:class:`str`] An additional description of the option, if any. Can only be up to 100 characters. - emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]] + emoji: Optional[:class:`PartialEmoji`] The emoji of the option, if available. default: :class:`bool` Whether this option is selected by default. @@ -368,7 +379,7 @@ class SelectOption: else: raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') - self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji + self.emoji: Optional[PartialEmoji] = emoji self.default: bool = default def __repr__(self) -> str: @@ -440,8 +451,7 @@ class TextInput(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: TextInputPayload, _=MISSING) -> None: - self.type: ComponentType = ComponentType.text_input + def __init__(self, data: TextInputPayload, *args) -> None: self.style: TextStyle = try_enum(TextStyle, data['style']) self.label: str = data['label'] self.custom_id: str = data['custom_id'] @@ -451,6 +461,11 @@ class TextInput(Component): self.min_length: Optional[int] = data.get('min_length') self.max_length: Optional[int] = data.get('max_length') + @property + def type(self) -> Literal[ComponentType.text_input]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.text_input + def to_dict(self) -> dict: return { 'type': self.type.value, @@ -500,17 +515,22 @@ class TextInput(Component): self.value = value -def _component_factory(data: ComponentPayload, message: Message = MISSING) -> Component: - # The type checker does not properly do narrowing here - component_type = data['type'] - if component_type == 1: +@overload +def _component_factory(data: ActionRowChildComponentPayload, message: Message = ...) -> Optional[ActionRowChildComponentType]: + ... + + +@overload +def _component_factory(data: ComponentPayload, message: Message = ...) -> Optional[Union[ActionRow, ActionRowChildComponentType]]: + ... + + +def _component_factory(data: ComponentPayload, message: Message = MISSING) -> Optional[Union[ActionRow, ActionRowChildComponentType]]: + if data['type'] == 1: return ActionRow(data, message) - elif component_type == 2: - return Button(data, message) # type: ignore - elif component_type == 3: - return SelectMenu(data, message) # type: ignore - elif component_type == 4: - return TextInput(data, message) # type: ignore - else: - as_enum = try_enum(ComponentType, component_type) - return Component._raw_construct(type=as_enum) + elif data['type'] == 2: + return Button(data, message) + elif data['type'] == 3: + return SelectMenu(data, message) + elif data['type'] == 4: + return TextInput(data, message) diff --git a/discord/message.py b/discord/message.py index 05633e477..d74035403 100644 --- a/discord/message.py +++ b/discord/message.py @@ -94,7 +94,7 @@ if TYPE_CHECKING: from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent from .abc import Snowflake from .abc import GuildChannel, MessageableChannel - from .components import ActionRow + from .components import ActionRow, ActionRowChildComponentType from .state import ConnectionState from .channel import TextChannel from .mentions import AllowedMentions @@ -102,6 +102,7 @@ if TYPE_CHECKING: from .role import Role EmojiInputType = Union[Emoji, PartialEmoji, str] + MessageComponentType = Union[ActionRow, ActionRowChildComponentType] __all__ = ( @@ -1254,7 +1255,7 @@ class Message(PartialMessage, Hashable): A list of sticker items given to the message. .. versionadded:: 1.6 - components: List[:class:`Component`] + components: List[Union[:class:`ActionRow`, :class:`Button`, :class:`SelectMenu`]] A list of components in the message. .. versionadded:: 2.0 @@ -1311,6 +1312,7 @@ class Message(PartialMessage, Hashable): mentions: List[Union[User, Member]] author: Union[User, Member] role_mentions: List[Role] + components: List[MessageComponentType] def __init__( self, @@ -1337,7 +1339,6 @@ class Message(PartialMessage, Hashable): self.content: str = data['content'] self.nonce: Optional[Union[int, str]] = data.get('nonce') self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] - self.components: List[ActionRow] = [_component_factory(d, self) for d in data.get('components', [])] # type: ignore # Will always be rows here self.call: Optional[CallMessage] = None try: @@ -1387,7 +1388,7 @@ class Message(PartialMessage, Hashable): # The channel will be the correct type here ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore - for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction'): + for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction', 'components'): try: getattr(self, f'_handle_{handler}')(data[handler]) except KeyError: @@ -1579,8 +1580,14 @@ class Message(PartialMessage, Hashable): call['participants'] = participants self.call = CallMessage(message=self, **call) - def _handle_components(self, components: List[ComponentPayload]): - self.components: List[ActionRow] = [_component_factory(d, self) for d in components] # type: ignore # Will always be rows here + def _handle_components(self, data: List[ComponentPayload]) -> None: + self.components = [] + + for component_data in data: + component = _component_factory(component_data, self) + + if component is not None: + self.components.append(component) def _handle_interaction(self, data: MessageInteractionPayload): self.interaction = Interaction._from_message(self, **data) diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index 1f3def16c..7341df938 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -41,7 +41,7 @@ if TYPE_CHECKING: from .state import ConnectionState from datetime import datetime - from .types.message import PartialEmoji as PartialEmojiPayload + from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload from .types.activity import ActivityEmoji @@ -148,13 +148,16 @@ class PartialEmoji(_EmojiTag, AssetMixin): return cls(name=value, id=None, animated=False) - def to_dict(self) -> Dict[str, Any]: - o: Dict[str, Any] = {'name': self.name} - if self.id: - o['id'] = self.id + def to_dict(self) -> EmojiPayload: + payload: EmojiPayload = { + 'id': self.id, + 'name': self.name, + } + if self.animated: - o['animated'] = self.animated - return o + payload['animated'] = self.animated + + return payload def _to_partial(self) -> PartialEmoji: return self diff --git a/discord/types/components.py b/discord/types/components.py index 9c197bebd..697490bd6 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -36,7 +36,7 @@ TextStyle = Literal[1, 2] class ActionRow(TypedDict): type: Literal[1] - components: List[Component] + components: List[ActionRowChildComponent] class ButtonComponent(TypedDict): @@ -79,4 +79,5 @@ class TextInput(TypedDict): max_length: NotRequired[int] -Component = Union[ActionRow, ButtonComponent, SelectMenu, TextInput] +ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput] +Component = Union[ActionRow, ActionRowChildComponent] diff --git a/discord/types/interactions.py b/discord/types/interactions.py index 293b9ac27..b5ab32cac 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -186,7 +186,7 @@ ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData, class ModalSubmitInteractionData(TypedDict): custom_id: str - components: List[ModalSubmitActionRowInteractionData] + components: List[ModalSubmitComponentInteractionData] InteractionData = Union[