From a0dfdb9b1d64fa23788069fa7b49c941f5ae73cb Mon Sep 17 00:00:00 2001 From: Rapptz Date: Thu, 28 Apr 2022 12:15:58 -0400 Subject: [PATCH] Fix multiple view instances not dispatching in app commands responses Due to a quirk in InteractionResponse.send_message not returning a message, all messages sent with an associated View would end up having no message_id set. When multiple instances of a View are responded to in a slash command context, this meant that the newest one would override the storage of the older one. Ultimately leading to the first view instance causing interaction failures. Since fetching the original message is an unacceptable solution to the problem due to incurred requests, the next best thing is to store an intermediate interaction_id as a stop gap to differentiate between the multiple instances. This change however, came with its own set of complications. Due to the interaction_id being an intermediate stop gap, the underlying storage of the view store had to be changed to accommodate the different way of accessing the data. Mainly, the interaction_id key had to be quick to swap and remove to another key. This solution attempts to change the interaction_id interim key with a full fledged message_id key when it receives one if it's possible. Note that the only way to obtain the interaction_id back from the component interaction is to retrieve it from the MessageInteraction data structure. This is because using the interaction_id of the button press would be a different interaction ID than the one set as an interim key. As a consequence, this stop gap only works for application command based interactions. I am not aware of this bug manifesting in component based interactions. This patch also fixes a bug with ViewStore.remove_view not working due to a bug being suppressed by a type: ignore comment. It also removes the older __verify_integrity helper method since clean-up is already done through View.stop() or View timeout. Hopefully in the near future, the `/callback` endpoint will return actual message data and this stop gap fix will no longer be necessary. --- discord/interactions.py | 7 ++- discord/message.py | 6 ++- discord/state.py | 14 ++++-- discord/ui/view.py | 97 +++++++++++++++++++++++++++-------------- 4 files changed, 85 insertions(+), 39 deletions(-) diff --git a/discord/interactions.py b/discord/interactions.py index 5a3297cf6..317ca2f0c 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -454,7 +454,7 @@ class Interaction: state = _InteractionMessageState(self, self._state) message = InteractionMessage(state=state, channel=self.channel, data=data) # type: ignore if view and not view.is_finished(): - self._state.store_view(view, message.id) + self._state.store_view(view, message.id, interaction_id=self.id) return message async def delete_original_message(self) -> None: @@ -682,7 +682,10 @@ class InteractionResponse: if ephemeral and view.timeout is None: view.timeout = 15 * 60.0 - self._parent._state.store_view(view) + # If the interaction type isn't an application command then there's no way + # to obtain this interaction_id again, so just default to None + entity_id = parent.id if parent.type is InteractionType.application_command else None + self._parent._state.store_view(view, entity_id) self._responded = True diff --git a/discord/message.py b/discord/message.py index d672108ac..921c5cb2e 100644 --- a/discord/message.py +++ b/discord/message.py @@ -882,7 +882,11 @@ class PartialMessage(Hashable): message = Message(state=self._state, channel=self.channel, data=data) if view and not view.is_finished(): - self._state.store_view(view, self.id) + interaction: Optional[MessageInteraction] = getattr(self, 'interaction', None) + if interaction is not None: + self._state.store_view(view, self.id, interaction_id=interaction.id) + else: + self._state.store_view(view, self.id) if delete_after is not None: await self.delete(delay=delete_after) diff --git a/discord/state.py b/discord/state.py index c9e21658f..d9c1a3d3b 100644 --- a/discord/state.py +++ b/discord/state.py @@ -361,7 +361,9 @@ class ConnectionState: self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker - def store_view(self, view: View, message_id: Optional[int] = None) -> None: + def store_view(self, view: View, message_id: Optional[int] = None, interaction_id: Optional[int] = None) -> None: + if interaction_id is not None: + self._view_store.remove_interaction_mapping(interaction_id) self._view_store.add_view(view, message_id) def prevent_view_updates_for(self, message_id: int) -> Optional[View]: @@ -631,8 +633,14 @@ class ConnectionState: else: self.dispatch('raw_message_edit', raw) - if 'components' in data and self._view_store.is_message_tracked(raw.message_id): - self._view_store.update_from_message(raw.message_id, data['components']) + if 'components' in data: + try: + entity_id = int(data['interaction']['id']) + except (KeyError, ValueError): + entity_id = raw.message_id + + if self._view_store.is_message_tracked(entity_id): + self._view_store.update_from_message(entity_id, data['components']) def parse_message_reaction_add(self, data: gw.MessageReactionAddEvent) -> None: emoji = PartialEmoji.from_dict(data['emoji']) diff --git a/discord/ui/view.py b/discord/ui/view.py index 0195b9277..d89386483 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -184,6 +184,7 @@ class View: self._children: List[Item[Self]] = self._init_children() self.__weights = _ViewWeights(self._children) self.id: str = os.urandom(16).hex() + self._cache_key: Optional[int] = None self.__cancel_callback: Optional[Callable[[View], None]] = None self.__timeout_expiry: Optional[float] = None self.__timeout_task: Optional[asyncio.Task[None]] = None @@ -512,8 +513,8 @@ class View: class ViewStore: def __init__(self, state: ConnectionState): - # (component_type, message_id, custom_id): (View, Item) - self._views: Dict[Tuple[int, Optional[int], str], Tuple[View, Item]] = {} + # entity_id: {(component_type, custom_id): Item} + self._views: Dict[Optional[int], Dict[Tuple[int, str], Item[View]]] = {} # message_id: View self._synced_message_views: Dict[int, View] = {} # custom_id: Modal @@ -524,34 +525,26 @@ class ViewStore: def persistent_views(self) -> Sequence[View]: # fmt: off views = { - view.id: view - for (_, (view, _)) in self._views.items() - if view.is_persistent() + item.view.id: item.view + for items in self._views.values() + for item in items.values() + if item.view and item.view.is_persistent() } # fmt: on return list(views.values()) - def __verify_integrity(self): - to_remove: List[Tuple[int, Optional[int], str]] = [] - for (k, (view, _)) in self._views.items(): - if view.is_finished(): - to_remove.append(k) - - for k in to_remove: - del self._views[k] - def add_view(self, view: View, message_id: Optional[int] = None) -> None: view._start_listening_from_store(self) if view.__discord_ui_modal__: self._modals[view.custom_id] = view # type: ignore return - self.__verify_integrity() - + dispatch_info = self._views.setdefault(message_id, {}) for item in view._children: if item.is_dispatchable(): - self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore + dispatch_info[(item.type.value, item.custom_id)] = item # type: ignore + view._cache_key = message_id if message_id is not None: self._synced_message_views[message_id] = view @@ -560,28 +553,62 @@ class ViewStore: self._modals.pop(view.custom_id, None) # type: ignore return - for item in view._children: - if item.is_dispatchable(): - self._views.pop((item.type.value, item.custom_id), None) # type: ignore + dispatch_info = self._views.get(view._cache_key) + if dispatch_info: + for item in view._children: + if item.is_dispatchable(): + dispatch_info.pop((item.type.value, item.custom_id), None) # type: ignore + + if len(dispatch_info) == 0: + self._views.pop(view._cache_key, None) - for key, value in self._synced_message_views.items(): - if value.id == view.id: - del self._synced_message_views[key] - break + self._synced_message_views.pop(view._cache_key, None) # type: ignore def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None: - self.__verify_integrity() - message_id: Optional[int] = interaction.message and interaction.message.id - key = (component_type, message_id, custom_id) - # Fallback to None message_id searches in case a persistent view - # was added without an associated message_id - value = self._views.get(key) or self._views.get((component_type, None, custom_id)) - if value is None: + interaction_id: Optional[int] = None + message_id: Optional[int] = None + # Realistically, in a component based interaction the Interaction.message will never be None + # However, this guard is just in case Discord screws up somehow + msg = interaction.message + if msg is not None: + message_id = msg.id + if msg.interaction: + interaction_id = msg.interaction.id + + key = (component_type, custom_id) + + # The entity_id can either be message_id, interaction_id, or None in that priority order. + item: Optional[Item[View]] = None + if message_id is not None: + item = self._views.get(message_id, {}).get(key) + + if item is None and interaction_id is not None: + try: + items = self._views.pop(interaction_id) + except KeyError: + item = None + else: + item = items.get(key) + # If we actually got the items, then these keys should probably be moved + # to the proper message_id instead of the interaction_id as they are now. + # An interaction_id is only used as a temporary stop gap for + # InteractionResponse.send_message so multiple view instances do not + # override each other. + # NOTE: Fix this mess if /callback endpoint ever gets proper return types + self._views.setdefault(message_id, {}).update(items) + + if item is None: + # Fallback to None message_id searches in case a persistent view + # was added without an associated message_id + item = self._views.get(None, {}).get(key) + + # If 3 lookups failed at this point then just discard it + if item is None: return - view, item = value item._refresh_state(interaction.data) # type: ignore - view._dispatch_item(item, interaction) + # Note, at this point the View is *not* None + item.view._dispatch_item(item, interaction) # type: ignore def dispatch_modal( self, @@ -597,6 +624,10 @@ class ViewStore: modal._refresh(components) modal._dispatch_submit(interaction) + def remove_interaction_mapping(self, interaction_id: int) -> None: + # This is called before re-adding the view + self._views.pop(interaction_id, None) + def is_message_tracked(self, message_id: int) -> bool: return message_id in self._synced_message_views