diff --git a/discord/interactions.py b/discord/interactions.py index 5a3297cf6..317ca2f0c 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -454,7 +454,7 @@ class Interaction: state = _InteractionMessageState(self, self._state) message = InteractionMessage(state=state, channel=self.channel, data=data) # type: ignore if view and not view.is_finished(): - self._state.store_view(view, message.id) + self._state.store_view(view, message.id, interaction_id=self.id) return message async def delete_original_message(self) -> None: @@ -682,7 +682,10 @@ class InteractionResponse: if ephemeral and view.timeout is None: view.timeout = 15 * 60.0 - self._parent._state.store_view(view) + # If the interaction type isn't an application command then there's no way + # to obtain this interaction_id again, so just default to None + entity_id = parent.id if parent.type is InteractionType.application_command else None + self._parent._state.store_view(view, entity_id) self._responded = True diff --git a/discord/message.py b/discord/message.py index d672108ac..921c5cb2e 100644 --- a/discord/message.py +++ b/discord/message.py @@ -882,7 +882,11 @@ class PartialMessage(Hashable): message = Message(state=self._state, channel=self.channel, data=data) if view and not view.is_finished(): - self._state.store_view(view, self.id) + interaction: Optional[MessageInteraction] = getattr(self, 'interaction', None) + if interaction is not None: + self._state.store_view(view, self.id, interaction_id=interaction.id) + else: + self._state.store_view(view, self.id) if delete_after is not None: await self.delete(delay=delete_after) diff --git a/discord/state.py b/discord/state.py index c9e21658f..d9c1a3d3b 100644 --- a/discord/state.py +++ b/discord/state.py @@ -361,7 +361,9 @@ class ConnectionState: self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker - def store_view(self, view: View, message_id: Optional[int] = None) -> None: + def store_view(self, view: View, message_id: Optional[int] = None, interaction_id: Optional[int] = None) -> None: + if interaction_id is not None: + self._view_store.remove_interaction_mapping(interaction_id) self._view_store.add_view(view, message_id) def prevent_view_updates_for(self, message_id: int) -> Optional[View]: @@ -631,8 +633,14 @@ class ConnectionState: else: self.dispatch('raw_message_edit', raw) - if 'components' in data and self._view_store.is_message_tracked(raw.message_id): - self._view_store.update_from_message(raw.message_id, data['components']) + if 'components' in data: + try: + entity_id = int(data['interaction']['id']) + except (KeyError, ValueError): + entity_id = raw.message_id + + if self._view_store.is_message_tracked(entity_id): + self._view_store.update_from_message(entity_id, data['components']) def parse_message_reaction_add(self, data: gw.MessageReactionAddEvent) -> None: emoji = PartialEmoji.from_dict(data['emoji']) diff --git a/discord/ui/view.py b/discord/ui/view.py index 0195b9277..d89386483 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -184,6 +184,7 @@ class View: self._children: List[Item[Self]] = self._init_children() self.__weights = _ViewWeights(self._children) self.id: str = os.urandom(16).hex() + self._cache_key: Optional[int] = None self.__cancel_callback: Optional[Callable[[View], None]] = None self.__timeout_expiry: Optional[float] = None self.__timeout_task: Optional[asyncio.Task[None]] = None @@ -512,8 +513,8 @@ class View: class ViewStore: def __init__(self, state: ConnectionState): - # (component_type, message_id, custom_id): (View, Item) - self._views: Dict[Tuple[int, Optional[int], str], Tuple[View, Item]] = {} + # entity_id: {(component_type, custom_id): Item} + self._views: Dict[Optional[int], Dict[Tuple[int, str], Item[View]]] = {} # message_id: View self._synced_message_views: Dict[int, View] = {} # custom_id: Modal @@ -524,34 +525,26 @@ class ViewStore: def persistent_views(self) -> Sequence[View]: # fmt: off views = { - view.id: view - for (_, (view, _)) in self._views.items() - if view.is_persistent() + item.view.id: item.view + for items in self._views.values() + for item in items.values() + if item.view and item.view.is_persistent() } # fmt: on return list(views.values()) - def __verify_integrity(self): - to_remove: List[Tuple[int, Optional[int], str]] = [] - for (k, (view, _)) in self._views.items(): - if view.is_finished(): - to_remove.append(k) - - for k in to_remove: - del self._views[k] - def add_view(self, view: View, message_id: Optional[int] = None) -> None: view._start_listening_from_store(self) if view.__discord_ui_modal__: self._modals[view.custom_id] = view # type: ignore return - self.__verify_integrity() - + dispatch_info = self._views.setdefault(message_id, {}) for item in view._children: if item.is_dispatchable(): - self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore + dispatch_info[(item.type.value, item.custom_id)] = item # type: ignore + view._cache_key = message_id if message_id is not None: self._synced_message_views[message_id] = view @@ -560,28 +553,62 @@ class ViewStore: self._modals.pop(view.custom_id, None) # type: ignore return - for item in view._children: - if item.is_dispatchable(): - self._views.pop((item.type.value, item.custom_id), None) # type: ignore + dispatch_info = self._views.get(view._cache_key) + if dispatch_info: + for item in view._children: + if item.is_dispatchable(): + dispatch_info.pop((item.type.value, item.custom_id), None) # type: ignore + + if len(dispatch_info) == 0: + self._views.pop(view._cache_key, None) - for key, value in self._synced_message_views.items(): - if value.id == view.id: - del self._synced_message_views[key] - break + self._synced_message_views.pop(view._cache_key, None) # type: ignore def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None: - self.__verify_integrity() - message_id: Optional[int] = interaction.message and interaction.message.id - key = (component_type, message_id, custom_id) - # Fallback to None message_id searches in case a persistent view - # was added without an associated message_id - value = self._views.get(key) or self._views.get((component_type, None, custom_id)) - if value is None: + interaction_id: Optional[int] = None + message_id: Optional[int] = None + # Realistically, in a component based interaction the Interaction.message will never be None + # However, this guard is just in case Discord screws up somehow + msg = interaction.message + if msg is not None: + message_id = msg.id + if msg.interaction: + interaction_id = msg.interaction.id + + key = (component_type, custom_id) + + # The entity_id can either be message_id, interaction_id, or None in that priority order. + item: Optional[Item[View]] = None + if message_id is not None: + item = self._views.get(message_id, {}).get(key) + + if item is None and interaction_id is not None: + try: + items = self._views.pop(interaction_id) + except KeyError: + item = None + else: + item = items.get(key) + # If we actually got the items, then these keys should probably be moved + # to the proper message_id instead of the interaction_id as they are now. + # An interaction_id is only used as a temporary stop gap for + # InteractionResponse.send_message so multiple view instances do not + # override each other. + # NOTE: Fix this mess if /callback endpoint ever gets proper return types + self._views.setdefault(message_id, {}).update(items) + + if item is None: + # Fallback to None message_id searches in case a persistent view + # was added without an associated message_id + item = self._views.get(None, {}).get(key) + + # If 3 lookups failed at this point then just discard it + if item is None: return - view, item = value item._refresh_state(interaction.data) # type: ignore - view._dispatch_item(item, interaction) + # Note, at this point the View is *not* None + item.view._dispatch_item(item, interaction) # type: ignore def dispatch_modal( self, @@ -597,6 +624,10 @@ class ViewStore: modal._refresh(components) modal._dispatch_submit(interaction) + def remove_interaction_mapping(self, interaction_id: int) -> None: + # This is called before re-adding the view + self._views.pop(interaction_id, None) + def is_message_tracked(self, message_id: int) -> bool: return message_id in self._synced_message_views