From 7267d18d9e8afe14e2171c9c7b8a1d9bba57f896 Mon Sep 17 00:00:00 2001 From: Lilly Rose Berner Date: Mon, 16 May 2022 21:30:03 +0200 Subject: [PATCH] Improve component typing --- discord/components.py | 136 +++++++++++++++++++++++----------- discord/message.py | 19 +++-- discord/partial_emoji.py | 17 +++-- discord/state.py | 2 +- discord/types/components.py | 5 +- discord/types/interactions.py | 2 +- discord/ui/button.py | 1 - discord/ui/select.py | 1 - discord/ui/text_input.py | 3 +- discord/ui/view.py | 14 +++- 10 files changed, 131 insertions(+), 69 deletions(-) diff --git a/discord/components.py b/discord/components.py index 9e1d33d62..26bdc888e 100644 --- a/discord/components.py +++ b/discord/components.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, ClassVar, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union +from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload from .enums import try_enum, ComponentType, ButtonStyle, TextStyle from .utils import get_slots, MISSING from .partial_emoji import PartialEmoji, _EmojiTag @@ -39,9 +39,12 @@ if TYPE_CHECKING: SelectOption as SelectOptionPayload, ActionRow as ActionRowPayload, TextInput as TextInputPayload, + ActionRowChildComponent as ActionRowChildComponentPayload, ) from .emoji import Emoji + ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput'] + __all__ = ( 'Component', @@ -61,26 +64,26 @@ class Component: - :class:`ActionRow` - :class:`Button` - :class:`SelectMenu` + - :class:`TextInput` This class is abstract and cannot be instantiated. .. versionadded:: 2.0 - - Attributes - ------------ - type: :class:`ComponentType` - The type of component. """ - __slots__: Tuple[str, ...] = ('type',) + __slots__: Tuple[str, ...] = () __repr_info__: ClassVar[Tuple[str, ...]] - type: ComponentType def __repr__(self) -> str: attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) return f'<{self.__class__.__name__} {attrs}>' + @property + def type(self) -> ComponentType: + """:class:`ComponentType`: The type of component.""" + raise NotImplementedError + @classmethod def _raw_construct(cls, **kwargs) -> Self: self = cls.__new__(cls) @@ -93,7 +96,7 @@ class Component: setattr(self, slot, value) return self - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> ComponentPayload: raise NotImplementedError @@ -108,9 +111,7 @@ class ActionRow(Component): Attributes ------------ - type: :class:`ComponentType` - The type of component. - children: List[:class:`Component`] + children: List[Union[:class:`Button`, :class:`SelectMenu`, :class:`TextInput`]] The children components that this holds, if any. """ @@ -118,15 +119,25 @@ class ActionRow(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: ComponentPayload): - self.type: Literal[ComponentType.action_row] = ComponentType.action_row - self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])] + def __init__(self, data: ActionRowPayload, /) -> None: + self.children: List[ActionRowChildComponentType] = [] + + for component_data in data.get('components', []): + component = _component_factory(component_data) + + if component is not None: + self.children.append(component) + + @property + def type(self) -> Literal[ComponentType.action_row]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.action_row def to_dict(self) -> ActionRowPayload: return { - 'type': int(self.type), + 'type': self.type.value, 'components': [child.to_dict() for child in self.children], - } # type: ignore # Type checker does not understand these are the same + } class Button(Component): @@ -169,8 +180,7 @@ class Button(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: ButtonComponentPayload): - self.type: Literal[ComponentType.button] = ComponentType.button + def __init__(self, data: ButtonComponentPayload, /) -> None: self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) self.custom_id: Optional[str] = data.get('custom_id') self.url: Optional[str] = data.get('url') @@ -182,13 +192,21 @@ class Button(Component): except KeyError: self.emoji = None + @property + def type(self) -> Literal[ComponentType.button]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.button + def to_dict(self) -> ButtonComponentPayload: - payload = { + payload: ButtonComponentPayload = { 'type': 2, - 'style': int(self.style), - 'label': self.label, + 'style': self.style.value, 'disabled': self.disabled, } + + if self.label: + payload['label'] = self.label + if self.custom_id: payload['custom_id'] = self.custom_id @@ -198,7 +216,7 @@ class Button(Component): if self.emoji: payload['emoji'] = self.emoji.to_dict() - return payload # type: ignore # Type checker does not understand these are the same + return payload class SelectMenu(Component): @@ -243,8 +261,7 @@ class SelectMenu(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: SelectMenuPayload): - self.type: Literal[ComponentType.select] = ComponentType.select + def __init__(self, data: SelectMenuPayload, /) -> None: self.custom_id: str = data['custom_id'] self.placeholder: Optional[str] = data.get('placeholder') self.min_values: int = data.get('min_values', 1) @@ -252,6 +269,11 @@ class SelectMenu(Component): self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] self.disabled: bool = data.get('disabled', False) + @property + def type(self) -> Literal[ComponentType.select]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.select + def to_dict(self) -> SelectMenuPayload: payload: SelectMenuPayload = { 'type': self.type.value, @@ -275,7 +297,7 @@ class SelectOption: .. versionadded:: 2.0 - Attributes + Parameters ----------- label: :class:`str` The label of the option. This is displayed to users. @@ -291,6 +313,23 @@ class SelectOption: The emoji of the option, if available. default: :class:`bool` Whether this option is selected by default. + + Attributes + ----------- + label: :class:`str` + The label of the option. This is displayed to users. + Can only be up to 100 characters. + value: :class:`str` + The value of the option. This is not displayed to users. + If not provided when constructed then it defaults to the + label. Can only be up to 100 characters. + description: Optional[:class:`str`] + An additional description of the option, if any. + Can only be up to 100 characters. + emoji: Optional[:class:`PartialEmoji`] + The emoji of the option, if available. + default: :class:`bool` + Whether this option is selected by default. """ __slots__: Tuple[str, ...] = ( @@ -322,7 +361,7 @@ class SelectOption: else: raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') - self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji + self.emoji: Optional[PartialEmoji] = emoji self.default: bool = default def __repr__(self) -> str: @@ -364,7 +403,7 @@ class SelectOption: } if self.emoji: - payload['emoji'] = self.emoji.to_dict() # type: ignore # This Dict[str, Any] is compatible with PartialEmoji + payload['emoji'] = self.emoji.to_dict() if self.description: payload['description'] = self.description @@ -414,8 +453,7 @@ class TextInput(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: TextInputPayload) -> None: - self.type: Literal[ComponentType.text_input] = ComponentType.text_input + def __init__(self, data: TextInputPayload, /) -> None: self.style: TextStyle = try_enum(TextStyle, data['style']) self.label: str = data['label'] self.custom_id: str = data['custom_id'] @@ -425,6 +463,11 @@ class TextInput(Component): self.min_length: Optional[int] = data.get('min_length') self.max_length: Optional[int] = data.get('max_length') + @property + def type(self) -> Literal[ComponentType.text_input]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.text_input + def to_dict(self) -> TextInputPayload: payload: TextInputPayload = { 'type': self.type.value, @@ -457,19 +500,22 @@ class TextInput(Component): return self.value -def _component_factory(data: ComponentPayload) -> Component: - component_type = data['type'] - if component_type == 1: +@overload +def _component_factory(data: ActionRowChildComponentPayload) -> Optional[ActionRowChildComponentType]: + ... + + +@overload +def _component_factory(data: ComponentPayload) -> Optional[Union[ActionRow, ActionRowChildComponentType]]: + ... + + +def _component_factory(data: ComponentPayload) -> Optional[Union[ActionRow, ActionRowChildComponentType]]: + if data['type'] == 1: return ActionRow(data) - elif component_type == 2: - # The type checker does not properly do narrowing here. - return Button(data) # type: ignore - elif component_type == 3: - # The type checker does not properly do narrowing here. - return SelectMenu(data) # type: ignore - elif component_type == 4: - # The type checker does not properly do narrowing here. - return TextInput(data) # type: ignore - else: - as_enum = try_enum(ComponentType, component_type) - return Component._raw_construct(type=as_enum) + elif data['type'] == 2: + return Button(data) + elif data['type'] == 3: + return SelectMenu(data) + elif data['type'] == 4: + return TextInput(data) diff --git a/discord/message.py b/discord/message.py index 359fc051a..cfa1f969b 100644 --- a/discord/message.py +++ b/discord/message.py @@ -87,7 +87,7 @@ if TYPE_CHECKING: from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent from .abc import Snowflake from .abc import GuildChannel, MessageableChannel - from .components import Component + from .components import ActionRow, ActionRowChildComponentType from .state import ConnectionState from .channel import TextChannel from .mentions import AllowedMentions @@ -96,6 +96,7 @@ if TYPE_CHECKING: from .ui.view import View EmojiInputType = Union[Emoji, PartialEmoji, str] + MessageComponentType = Union[ActionRow, ActionRowChildComponentType] __all__ = ( @@ -1340,7 +1341,7 @@ class Message(PartialMessage, Hashable): A list of sticker items given to the message. .. versionadded:: 1.6 - components: List[:class:`Component`] + components: List[Union[:class:`ActionRow`, :class:`Button`, :class:`SelectMenu`]] A list of components in the message. .. versionadded:: 2.0 @@ -1392,6 +1393,7 @@ class Message(PartialMessage, Hashable): mentions: List[Union[User, Member]] author: Union[User, Member] role_mentions: List[Role] + components: List[MessageComponentType] def __init__( self, @@ -1418,7 +1420,6 @@ class Message(PartialMessage, Hashable): self.content: str = data['content'] self.nonce: Optional[Union[int, str]] = data.get('nonce') self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] - self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])] try: # if the channel doesn't have a guild attribute, we handle that @@ -1460,7 +1461,7 @@ class Message(PartialMessage, Hashable): # the channel will be the correct type here ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore - for handler in ('author', 'member', 'mentions', 'mention_roles'): + for handler in ('author', 'member', 'mentions', 'mention_roles', 'components'): try: getattr(self, f'_handle_{handler}')(data[handler]) except KeyError: @@ -1631,8 +1632,14 @@ class Message(PartialMessage, Hashable): if role is not None: self.role_mentions.append(role) - def _handle_components(self, components: List[ComponentPayload]): - self.components = [_component_factory(d) for d in components] + def _handle_components(self, data: List[ComponentPayload]) -> None: + self.components = [] + + for component_data in data: + component = _component_factory(component_data) + + if component is not None: + self.components.append(component) def _handle_interaction(self, data: MessageInteractionPayload): self.interaction = MessageInteraction(state=self._state, guild=self.guild, data=data) diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index 1fc31497f..5e0aea366 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -41,7 +41,7 @@ if TYPE_CHECKING: from .state import ConnectionState from datetime import datetime - from .types.message import PartialEmoji as PartialEmojiPayload + from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload from .types.activity import ActivityEmoji @@ -148,13 +148,16 @@ class PartialEmoji(_EmojiTag, AssetMixin): return cls(name=value, id=None, animated=False) - def to_dict(self) -> Dict[str, Any]: - o: Dict[str, Any] = {'name': self.name} - if self.id: - o['id'] = self.id + def to_dict(self) -> EmojiPayload: + payload: EmojiPayload = { + 'id': self.id, + 'name': self.name, + } + if self.animated: - o['animated'] = self.animated - return o + payload['animated'] = self.animated + + return payload def _to_partial(self) -> PartialEmoji: return self diff --git a/discord/state.py b/discord/state.py index 196d88c5d..dab982a95 100644 --- a/discord/state.py +++ b/discord/state.py @@ -730,7 +730,7 @@ class ConnectionState: inner_data = data['data'] custom_id = inner_data['custom_id'] components = inner_data['components'] - self._view_store.dispatch_modal(custom_id, interaction, components) # type: ignore + self._view_store.dispatch_modal(custom_id, interaction, components) self.dispatch('interaction', interaction) def parse_presence_update(self, data: gw.PresenceUpdateEvent) -> None: diff --git a/discord/types/components.py b/discord/types/components.py index 9c197bebd..697490bd6 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -36,7 +36,7 @@ TextStyle = Literal[1, 2] class ActionRow(TypedDict): type: Literal[1] - components: List[Component] + components: List[ActionRowChildComponent] class ButtonComponent(TypedDict): @@ -79,4 +79,5 @@ class TextInput(TypedDict): max_length: NotRequired[int] -Component = Union[ActionRow, ButtonComponent, SelectMenu, TextInput] +ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput] +Component = Union[ActionRow, ActionRowChildComponent] diff --git a/discord/types/interactions.py b/discord/types/interactions.py index 293b9ac27..b5ab32cac 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -186,7 +186,7 @@ ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData, class ModalSubmitInteractionData(TypedDict): custom_id: str - components: List[ModalSubmitActionRowInteractionData] + components: List[ModalSubmitComponentInteractionData] InteractionData = Union[ diff --git a/discord/ui/button.py b/discord/ui/button.py index 7622d9756..8844b9605 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -120,7 +120,6 @@ class Button(Item[V]): raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') self._underlying = ButtonComponent._raw_construct( - type=ComponentType.button, custom_id=custom_id, url=url, disabled=disabled, diff --git a/discord/ui/select.py b/discord/ui/select.py index 50d56f845..72148ccc3 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -117,7 +117,6 @@ class Select(Item[V]): options = [] if options is MISSING else options self._underlying = SelectMenu._raw_construct( custom_id=custom_id, - type=ComponentType.select, placeholder=placeholder, min_values=min_values, max_values=max_values, diff --git a/discord/ui/text_input.py b/discord/ui/text_input.py index a5611b2f2..ec42546a8 100644 --- a/discord/ui/text_input.py +++ b/discord/ui/text_input.py @@ -114,7 +114,6 @@ class TextInput(Item[V]): raise TypeError(f'expected custom_id to be str not {custom_id.__class__!r}') self._underlying = TextInputComponent._raw_construct( - type=ComponentType.text_input, label=label, style=style, custom_id=custom_id, @@ -238,7 +237,7 @@ class TextInput(Item[V]): @property def type(self) -> Literal[ComponentType.text_input]: - return ComponentType.text_input + return self._underlying.type def is_dispatchable(self) -> bool: return False diff --git a/discord/ui/view.py b/discord/ui/view.py index d89386483..5a13827f5 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -281,7 +281,7 @@ class View: one of its subclasses. """ view = View(timeout=timeout) - for component in _walk_all_components(message.components): + for component in _walk_all_components(message.components): # type: ignore view.add_item(_component_to_item(component)) return view @@ -634,7 +634,15 @@ class ViewStore: def remove_message_tracking(self, message_id: int) -> Optional[View]: return self._synced_message_views.pop(message_id, None) - def update_from_message(self, message_id: int, components: List[ComponentPayload]) -> None: + def update_from_message(self, message_id: int, data: List[ComponentPayload]) -> None: + components: List[Component] = [] + + for component_data in data: + component = _component_factory(component_data) + + if component is not None: + components.append(component) + # pre-req: is_message_tracked == true view = self._synced_message_views[message_id] - view._refresh([_component_factory(d) for d in components]) + view._refresh(components)