diff --git a/discord/ui/action_row.py b/discord/ui/action_row.py index a27e31ca9..7e6a6a37c 100644 --- a/discord/ui/action_row.py +++ b/discord/ui/action_row.py @@ -601,5 +601,5 @@ class ActionRow(Item[V]): self = cls() for cmp in component.children: - self.add_item(_component_to_item(cmp)) + self.add_item(_component_to_item(cmp, self)) return self diff --git a/discord/ui/button.py b/discord/ui/button.py index 9d68d411b..6950dadee 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -152,7 +152,6 @@ class Button(Item[V]): sku_id=sku_id, id=id, ) - self._parent: Optional[ActionRow] = None self.row = row self.id = id diff --git a/discord/ui/container.py b/discord/ui/container.py index 93dedfbe1..838661ff3 100644 --- a/discord/ui/container.py +++ b/discord/ui/container.py @@ -321,12 +321,17 @@ class Container(Item[V]): @classmethod def from_component(cls, component: ContainerComponent) -> Self: - return cls( - *[_component_to_item(c) for c in component.children], + self = cls( accent_colour=component.accent_colour, spoiler=component.spoiler, id=component.id, ) + self._children = [ + _component_to_item( + cmp, self + ) for cmp in component.children + ] + return self def walk_children(self) -> Generator[Item[V], None, None]: """An iterator that recursively walks through all the children of this container diff --git a/discord/ui/section.py b/discord/ui/section.py index 708ef68c5..320dbf4da 100644 --- a/discord/ui/section.py +++ b/discord/ui/section.py @@ -236,13 +236,14 @@ class Section(Item[V]): @classmethod def from_component(cls, component: SectionComponent) -> Self: - from .view import _component_to_item # >circular import< + from .view import _component_to_item - return cls( - *[_component_to_item(c) for c in component.components], - accessory=_component_to_item(component.accessory), - id=component.id, - ) + self = cls.__new__(cls) + self.accessory = _component_to_item(component.accessory, self) + self.id = component.id + self._children = [_component_to_item(c, self) for c in component.components] + + return self def to_components(self) -> List[Dict[str, Any]]: components = [] diff --git a/discord/ui/select.py b/discord/ui/select.py index 7695f759e..31f16bd88 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -264,7 +264,6 @@ class BaseSelect(Item[V]): self.row = row self.id = id - self._parent: Optional[ActionRow] = None self._values: List[PossibleValue] = [] @property diff --git a/discord/ui/view.py b/discord/ui/view.py index 273c5603e..d3eb45086 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -68,6 +68,7 @@ from ..components import ( SeparatorComponent, ThumbnailComponent, SelectOption, + Container as ContainerComponent, ) from ..utils import get as _utils_get, _get_as_snowflake from ..enums import SeparatorSpacing, TextStyle, try_enum, ButtonStyle @@ -106,52 +107,59 @@ def _walk_all_components(components: List[Component]) -> Iterator[Component]: yield item -def _component_to_item(component: Component) -> Item: +def _component_to_item(component: Component, parent: Optional[Item] = None) -> Item: if isinstance(component, ActionRowComponent): from .action_row import ActionRow - return ActionRow.from_component(component) - if isinstance(component, ButtonComponent): + item = ActionRow.from_component(component) + elif isinstance(component, ButtonComponent): from .button import Button - return Button.from_component(component) - if isinstance(component, SelectComponent): + item = Button.from_component(component) + elif isinstance(component, SelectComponent): from .select import BaseSelect - return BaseSelect.from_component(component) - if isinstance(component, SectionComponent): + item = BaseSelect.from_component(component) + elif isinstance(component, SectionComponent): from .section import Section - return Section.from_component(component) - if isinstance(component, TextDisplayComponent): + item = Section.from_component(component) + elif isinstance(component, TextDisplayComponent): from .text_display import TextDisplay - return TextDisplay.from_component(component) - if isinstance(component, MediaGalleryComponent): + item = TextDisplay.from_component(component) + elif isinstance(component, MediaGalleryComponent): from .media_gallery import MediaGallery - return MediaGallery.from_component(component) - if isinstance(component, FileComponent): + item = MediaGallery.from_component(component) + elif isinstance(component, FileComponent): from .file import File - return File.from_component(component) - if isinstance(component, SeparatorComponent): + item = File.from_component(component) + elif isinstance(component, SeparatorComponent): from .separator import Separator - return Separator.from_component(component) - if isinstance(component, ThumbnailComponent): + item = Separator.from_component(component) + elif isinstance(component, ThumbnailComponent): from .thumbnail import Thumbnail - return Thumbnail.from_component(component) + item = Thumbnail.from_component(component) + elif isinstance(component, ContainerComponent): + from .container import Container + + item = Container.from_component(component) + else: + item = Item.from_component(component) - return Item.from_component(component) + item._parent = parent + return item -def _component_data_to_item(data: ComponentPayload) -> Item: +def _component_data_to_item(data: ComponentPayload, parent: Optional[Item] = None) -> Item: if data['type'] == 1: from .action_row import ActionRow - return ActionRow( + item = ActionRow( *(_component_data_to_item(c) for c in data['components']), id=data.get('id'), ) @@ -160,7 +168,7 @@ def _component_data_to_item(data: ComponentPayload) -> Item: emoji = data.get('emoji') - return Button( + item = Button( style=try_enum(ButtonStyle, data['style']), custom_id=data.get('custom_id'), url=data.get('url'), @@ -172,7 +180,7 @@ def _component_data_to_item(data: ComponentPayload) -> Item: elif data['type'] == 3: from .select import Select - return Select( + item = Select( custom_id=data['custom_id'], placeholder=data.get('placeholder'), min_values=data.get('min_values', 1), @@ -184,7 +192,7 @@ def _component_data_to_item(data: ComponentPayload) -> Item: elif data['type'] == 4: from .text_input import TextInput - return TextInput( + item = TextInput( label=data['label'], style=try_enum(TextStyle, data['style']), custom_id=data['custom_id'], @@ -210,7 +218,7 @@ def _component_data_to_item(data: ComponentPayload) -> Item: 8: ChannelSelect, } - return cls_map[data['type']]( + item = cls_map[data['type']]( custom_id=data['custom_id'], # type: ignore # will always be present in this point placeholder=data.get('placeholder'), min_values=data.get('min_values', 1), @@ -222,7 +230,7 @@ def _component_data_to_item(data: ComponentPayload) -> Item: elif data['type'] == 9: from .section import Section - return Section( + item = Section( *(_component_data_to_item(c) for c in data['components']), accessory=_component_data_to_item(data['accessory']), id=data.get('id'), @@ -230,11 +238,11 @@ def _component_data_to_item(data: ComponentPayload) -> Item: elif data['type'] == 10: from .text_display import TextDisplay - return TextDisplay(data['content'], id=data.get('id')) + item = TextDisplay(data['content'], id=data.get('id')) elif data['type'] == 11: from .thumbnail import Thumbnail - return Thumbnail( + item = Thumbnail( UnfurledMediaItem._from_data(data['media'], None), description=data.get('description'), spoiler=data.get('spoiler', False), @@ -243,14 +251,14 @@ def _component_data_to_item(data: ComponentPayload) -> Item: elif data['type'] == 12: from .media_gallery import MediaGallery - return MediaGallery( + item = MediaGallery( *(MediaGalleryItem._from_data(m, None) for m in data['items']), id=data.get('id'), ) elif data['type'] == 13: from .file import File - return File( + item = File( UnfurledMediaItem._from_data(data['file'], None), spoiler=data.get('spoiler', False), id=data.get('id'), @@ -258,7 +266,7 @@ def _component_data_to_item(data: ComponentPayload) -> Item: elif data['type'] == 14: from .separator import Separator - return Separator( + item = Separator( visible=data.get('divider', True), spacing=try_enum(SeparatorSpacing, data.get('spacing', 1)), id=data.get('id'), @@ -266,7 +274,7 @@ def _component_data_to_item(data: ComponentPayload) -> Item: elif data['type'] == 17: from .container import Container - return Container( + item = Container( *(_component_data_to_item(c) for c in data['components']), accent_colour=data.get('accent_color'), spoiler=data.get('spoiler', False), @@ -275,6 +283,9 @@ def _component_data_to_item(data: ComponentPayload) -> Item: else: raise ValueError(f'invalid item with type {data["type"]} provided') + item._parent = parent + return item + class _ViewWeights: # fmt: off @@ -1120,7 +1131,7 @@ class ViewStore: try: base_item_index, base_item = next( (index, child) - for index, child in enumerate(view._children) + for index, child in enumerate(view.walk_children()) if child.type.value == component_type and getattr(child, 'custom_id', None) == custom_id ) except StopIteration: @@ -1132,8 +1143,11 @@ class ViewStore: _log.exception('Ignoring exception in dynamic item creation for %r', factory) return - # Swap the item in the view with our new dynamic item - view._children[base_item_index] = item # type: ignore + # Swap the item in the view or parent with our new dynamic item + if base_item._parent: + base_item._parent._children[base_item_index] = item # type: ignore + else: + view._children[base_item_index] = item # type: ignore item._view = view item._rendered_row = base_item._rendered_row item._refresh_state(interaction, interaction.data) # type: ignore