diff --git a/discord/ui/view.py b/discord/ui/view.py index 1f2e9848d..cd9c81958 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -49,11 +49,15 @@ import sys import time import os import copy + from .item import Item, ItemCallbackType from .dynamic import DynamicItem from ..components import ( Component, ActionRow as ActionRowComponent, + MediaGalleryItem, + SelectDefaultValue, + UnfurledMediaItem, _component_factory, Button as ButtonComponent, SelectMenu as SelectComponent, @@ -63,8 +67,11 @@ from ..components import ( FileComponent, SeparatorComponent, ThumbnailComponent, + SelectOption, ) -from ..utils import get as _utils_get +from ..utils import MISSING, get as _utils_get, _get_as_snowflake +from ..enums import SeparatorSize, TextStyle, try_enum, ButtonStyle +from ..emoji import PartialEmoji # fmt: off __all__ = ( @@ -80,7 +87,7 @@ if TYPE_CHECKING: from ..interactions import Interaction from ..message import Message - from ..types.components import ComponentBase as ComponentBasePayload + from ..types.components import ComponentBase as ComponentBasePayload, Component as ComponentPayload from ..types.interactions import ModalSubmitComponentInteractionData as ModalSubmitComponentInteractionDataPayload from ..state import ConnectionState from .modal import Modal @@ -100,6 +107,10 @@ def _walk_all_components(components: List[Component]) -> Iterator[Component]: def _component_to_item(component: Component) -> Item: + if isinstance(component, ActionRowComponent): + from .action_row import ActionRow + + return ActionRow.from_component(component) if isinstance(component, ButtonComponent): from .button import Button @@ -136,6 +147,141 @@ def _component_to_item(component: Component) -> Item: return Item.from_component(component) +def _component_data_to_item(data: ComponentPayload) -> Item: + if data['type'] == 1: + from .action_row import ActionRow + + return ActionRow( + *(_component_data_to_item(c) for c in data['components']), + id=data.get('id'), + ) + elif data['type'] == 2: + from .button import Button + + emoji = data.get('emoji') + + return Button( + style=try_enum(ButtonStyle, data['style']), + custom_id=data.get('custom_id'), + url=data.get('url'), + disabled=data.get('disabled', False), + emoji=PartialEmoji.from_dict(emoji) if emoji else None, + label=data.get('label'), + sku_id=_get_as_snowflake(data, 'sku_id'), + ) + elif data['type'] == 3: + from .select import Select + + return Select( + custom_id=data['custom_id'], + placeholder=data.get('placeholder'), + min_values=data.get('min_values', 1), + max_values=data.get('max_values', 1), + disabled=data.get('disabled', False), + id=data.get('id'), + options=[ + SelectOption.from_dict(o) + for o in data.get('options', []) + ], + ) + elif data['type'] == 4: + from .text_input import TextInput + + return TextInput( + label=data['label'], + style=try_enum(TextStyle, data['style']), + custom_id=data['custom_id'], + placeholder=data.get('placeholder'), + default=data.get('value'), + required=data.get('required', True), + min_length=data.get('min_length'), + max_length=data.get('max_length'), + id=data.get('id'), + ) + elif data['type'] in (5, 6, 7, 8): + from .select import ( + UserSelect, + RoleSelect, + MentionableSelect, + ChannelSelect, + ) + + cls_map: Dict[int, Type[Union[UserSelect, RoleSelect, MentionableSelect, ChannelSelect]]] = { + 5: UserSelect, + 6: RoleSelect, + 7: MentionableSelect, + 8: ChannelSelect, + } + + return cls_map[data['type']]( + custom_id=data['custom_id'], # type: ignore # will always be present in this point + placeholder=data.get('placeholder'), + min_values=data.get('min_values', 1), + max_values=data.get('max_values', 1), + disabled=data.get('disabled', False), + default_values=[ + SelectDefaultValue.from_dict(v) + for v in data.get('default_values', []) + ], + id=data.get('id'), + ) + elif data['type'] == 9: + from .section import Section + + return Section( + *(_component_data_to_item(c) for c in data['components']), + accessory=_component_data_to_item(data['accessory']), + id=data.get('id'), + ) + elif data['type'] == 10: + from .text_display import TextDisplay + + return TextDisplay(data['content'], id=data.get('id')) + elif data['type'] == 11: + from .thumbnail import Thumbnail + + return Thumbnail( + UnfurledMediaItem._from_data(data['media'], None), + description=data.get('description'), + spoiler=data.get('spoiler', False), + id=data.get('id'), + ) + elif data['type'] == 12: + from .media_gallery import MediaGallery + + return MediaGallery( + *(MediaGalleryItem._from_data(m, None) for m in data['items']), + id=data.get('id'), + ) + elif data['type'] == 13: + from .file import File + + return File( + UnfurledMediaItem._from_data(data['file'], None), + spoiler=data.get('spoiler', False), + id=data.get('id'), + ) + elif data['type'] == 14: + from .separator import Separator + + return Separator( + visible=data.get('divider', True), + spacing=try_enum(SeparatorSize, data.get('spacing', 1)), + id=data.get('id'), + ) + elif data['type'] == 17: + from .container import Container + + return Container( + *(_component_data_to_item(c) for c in data['components']), + accent_colour=data.get('accent_color'), + spoiler=data.get('spoiler', False), + id=data.get('type'), + ) + else: + raise ValueError(f'invalid item with type {data["type"]} provided') + + class _ViewWeights: # fmt: off __slots__ = ( @@ -599,6 +745,28 @@ class BaseView: # if it has this attribute then it can contain children yield from child.walk_children() # type: ignore + @classmethod + def _to_minimal_cls(cls) -> Type[Union[View, LayoutView]]: + if issubclass(cls, View): + return View + elif issubclass(cls, LayoutView): + return LayoutView + raise RuntimeError + + @classmethod + def from_dict(cls, data: List[ComponentPayload], *, timeout: Optional[float] = 180.0) -> Any: + cls = cls._to_minimal_cls() + self = cls(timeout=timeout) + + for raw in data: + item = _component_data_to_item(raw) + + if item._is_v2() and not self._is_v2(): + continue + + self.add_item(item) + return self + class View(BaseView): """Represents a UI view. @@ -616,6 +784,21 @@ class View(BaseView): __discord_ui_view__: ClassVar[bool] = True + if TYPE_CHECKING: + @classmethod + def from_dict(cls, data: List[ComponentPayload], *, timeout: Optional[float] = 180.0) -> View: + """Converts a :class:`list` of :class:`dict` s to a :class:`View` provided it is in the + format that Discord expects it to be in. + + You can find out about this format in the :ddocs:`official Discord documentation `. + + Parameters + ---------- + data: List[:class:`dict`] + The array of dictionaries to convert into a View. + """ + ... + def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -754,6 +937,21 @@ class LayoutView(BaseView): __discord_ui_layout_view__: ClassVar[bool] = True + if TYPE_CHECKING: + @classmethod + def from_dict(cls, data: List[ComponentPayload], *, timeout: Optional[float] = 180.0) -> LayoutView: + """Converts a :class:`list` of :class:`dict` s to a :class:`LayoutView` provided it is in the + format that Discord expects it to be in. + + You can find out about this format in the :ddocs:`official Discord documentation `. + + Parameters + ---------- + data: List[:class:`dict`] + The array of dictionaries to convert into a LayoutView. + """ + ... + def __init__(self, *, timeout: Optional[float] = 180.0) -> None: super().__init__(timeout=timeout)