diff --git a/discord/interactions.py b/discord/interactions.py index 665ee92fc..c580a5e65 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -1071,24 +1071,18 @@ class InteractionResponse(Generic[ClientT]): proxy_auth=http.proxy_auth, params=params, ) - self._response_type = InteractionResponseType.channel_message - ret = InteractionCallbackResponse( - data=response, - parent=self._parent, - state=self._parent._state, - type=self._response_type, - ) if view is not MISSING and not view.is_finished(): if ephemeral and view.timeout is None: view.timeout = 15 * 60.0 - # this assertion should never fail because the resource of a send_message - # response will always be an InteractionMessage - assert isinstance(ret.resource, InteractionMessage) - entity_id = ret.resource.id if parent.type is InteractionType.application_command else None + # If the interaction type isn't an application command then there's no way + # to obtain this interaction_id again, so just default to None + entity_id = parent.id if parent.type is InteractionType.application_command else None self._parent._state.store_view(view, entity_id) + self._response_type = InteractionResponseType.channel_message + if delete_after is not None: async def inner_call(delay: float = delete_after): @@ -1100,7 +1094,12 @@ class InteractionResponse(Generic[ClientT]): asyncio.create_task(inner_call()) - return ret + return InteractionCallbackResponse( + data=response, + parent=self._parent, + state=self._parent._state, + type=self._response_type, + ) @overload async def edit_message( @@ -1209,7 +1208,14 @@ class InteractionResponse(Generic[ClientT]): parent = self._parent msg = parent.message state = parent._state - message_id = msg and msg.id + if msg is not None: + message_id = msg.id + # If this was invoked via an application command then we can use its original interaction ID + # Since this is used as a cache key for view updates + original_interaction_id = msg.interaction_metadata.id if msg.interaction_metadata is not None else None + else: + message_id = None + original_interaction_id = None if parent.type not in (InteractionType.component, InteractionType.modal_submit): return @@ -1247,7 +1253,7 @@ class InteractionResponse(Generic[ClientT]): ) if view and not view.is_finished(): - state.store_view(view, message_id) + state.store_view(view, message_id, interaction_id=original_interaction_id) self._response_type = InteractionResponseType.message_update diff --git a/discord/message.py b/discord/message.py index fd781aeee..0057b06f8 100644 --- a/discord/message.py +++ b/discord/message.py @@ -1444,7 +1444,11 @@ class PartialMessage(Hashable): message = Message(state=self._state, channel=self.channel, data=data) if view and not view.is_finished(): - self._state.store_view(view, self.id) + interaction: Optional[MessageInteraction] = getattr(self, 'interaction', None) + if interaction is not None: + self._state.store_view(view, self.id, interaction_id=interaction.id) + else: + self._state.store_view(view, self.id) if delete_after is not None: await self.delete(delay=delete_after) diff --git a/discord/state.py b/discord/state.py index 667c932bf..37bd138a7 100644 --- a/discord/state.py +++ b/discord/state.py @@ -412,7 +412,9 @@ class ConnectionState(Generic[ClientT]): self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker - def store_view(self, view: BaseView, message_id: Optional[int] = None) -> None: + def store_view(self, view: BaseView, message_id: Optional[int] = None, interaction_id: Optional[int] = None) -> None: + if interaction_id is not None: + self._view_store.remove_interaction_mapping(interaction_id) self._view_store.add_view(view, message_id) def prevent_view_updates_for(self, message_id: int) -> Optional[BaseView]: @@ -733,7 +735,11 @@ class ConnectionState(Generic[ClientT]): self.dispatch('raw_message_edit', raw) if 'components' in data: - entity_id = raw.message_id + try: + entity_id = int(data['interaction']['id']) # pyright: ignore[reportTypedDictNotRequiredAccess] + except (KeyError, ValueError): + entity_id = raw.message_id + if self._view_store.is_message_tracked(entity_id): self._view_store.update_from_message(entity_id, data['components']) diff --git a/discord/ui/view.py b/discord/ui/view.py index d21ea5661..10519cf5a 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -1162,12 +1162,15 @@ class ViewStore: def dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None: self.dispatch_dynamic_items(component_type, custom_id, interaction) + interaction_id: Optional[int] = None message_id: Optional[int] = None # Realistically, in a component based interaction the Interaction.message will never be None # However, this guard is just in case Discord screws up somehow msg = interaction.message if msg is not None: message_id = msg.id + if msg.interaction_metadata: + interaction_id = msg.interaction_metadata.id key = (component_type, custom_id) @@ -1176,10 +1179,27 @@ class ViewStore: if message_id is not None: item = self._views.get(message_id, {}).get(key) + if item is None and interaction_id is not None: + try: + items = self._views.pop(interaction_id) + except KeyError: + item = None + else: + item = items.get(key) + # If we actually got the items, then these keys should probably be moved + # to the proper message_id instead of the interaction_id as they are now. + # An interaction_id is only used as a temporary stop gap for + # InteractionResponse.send_message so multiple view instances do not + # override each other. + # NOTE: Fix this mess if /callback endpoint ever gets proper return types + self._views.setdefault(message_id, {}).update(items) + if item is None: + # Fallback to None message_id searches in case a persistent view + # was added without an associated message_id item = self._views.get(None, {}).get(key) - # If 2 lookups failed at this point then just discard it + # If 3 lookups failed at this point then just discard it if item is None: return @@ -1199,6 +1219,11 @@ class ViewStore: modal._dispatch_submit(interaction, components) + def remove_interaction_mapping(self, interaction_id: int) -> None: + # This is called before re-adding the view + self._views.pop(interaction_id, None) + self._synced_message_views.pop(interaction_id, None) + def is_message_tracked(self, message_id: int) -> bool: return message_id in self._synced_message_views