From e168fd38a009cd3965f01bf79d9844ab661adacb Mon Sep 17 00:00:00 2001 From: dolfies Date: Sat, 13 Nov 2021 18:22:26 -0500 Subject: [PATCH] Implement button clicking --- discord/components.py | 59 +++++++++++++++++++++++++++++++++++-------- discord/http.py | 5 +++- discord/message.py | 25 ++++-------------- 3 files changed, 58 insertions(+), 31 deletions(-) diff --git a/discord/components.py b/discord/components.py index 491716076..116b80c88 100644 --- a/discord/components.py +++ b/discord/components.py @@ -24,9 +24,11 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations +from datetime import datetime from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union + from .enums import try_enum, ComponentType, ButtonStyle -from .utils import get_slots, MISSING +from .utils import get_slots, MISSING, time_snowflake from .partial_emoji import PartialEmoji, _EmojiTag if TYPE_CHECKING: @@ -38,6 +40,7 @@ if TYPE_CHECKING: ActionRow as ActionRowPayload, ) from .emoji import Emoji + from .message import Message __all__ = ( @@ -70,10 +73,11 @@ class Component: The type of component. """ - __slots__: Tuple[str, ...] = ('type',) + __slots__: Tuple[str, ...] = ('type', 'message') __repr_info__: ClassVar[Tuple[str, ...]] type: ComponentType + message: Message def __repr__(self) -> str: attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) @@ -110,15 +114,18 @@ class ActionRow(Component): The type of component. children: List[:class:`Component`] The children components that this holds, if any. + message: :class:`Message` + The originating message. """ __slots__: Tuple[str, ...] = ('children',) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: ComponentPayload): + def __init__(self, data: ComponentPayload, message: Message): + self.message = message self.type: ComponentType = try_enum(ComponentType, data['type']) - self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])] + self.children: List[Component] = [_component_factory(d, message) for d in data.get('components', [])] def to_dict(self) -> ActionRowPayload: return { @@ -149,6 +156,8 @@ class Button(Component): The label of the button, if any. emoji: Optional[:class:`PartialEmoji`] The emoji of the button, if available. + message: :class:`Message` + The originating message, if any. """ __slots__: Tuple[str, ...] = ( @@ -162,7 +171,8 @@ class Button(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: ButtonComponentPayload): + def __init__(self, data: ButtonComponentPayload, message: Message): + self.message = message self.type: ComponentType = try_enum(ComponentType, data['type']) self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) self.custom_id: Optional[str] = data.get('custom_id') @@ -193,6 +203,32 @@ class Button(Component): return payload # type: ignore + async def click(self): + """|coro| + + Clicks the button. + + Raises + ------- + HTTPException + Clicking the button failed. + """ + message = self.message + payload = { + 'application_id': str(message.application_id), + 'channel_id': str(message.channel.id), + 'data': { + 'component_type': 2, + 'custom_id': self.custom_id, + }, + 'guild_id': message.guild and str(message.guild.id), + 'message_flags': message.flags.value, + 'message_id': str(message.id), + 'nonce': str(time_snowflake(datetime.utcnow())), + 'type': 3, # Should be an enum but eh + } + await message._state.http.interact(payload) # type: ignore + class SelectMenu(Component): """Represents a select menu from the Discord Bot UI Kit. @@ -218,6 +254,8 @@ class SelectMenu(Component): A list of options that can be selected in this menu. disabled: :class:`bool` Whether the select is disabled or not. + message: :class:`Message` + The originating message, if any. """ __slots__: Tuple[str, ...] = ( @@ -231,7 +269,8 @@ class SelectMenu(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: SelectMenuPayload): + def __init__(self, data: SelectMenuPayload, message: Message): + self.message = message self.type = ComponentType.select self.custom_id: str = data['custom_id'] self.placeholder: Optional[str] = data.get('placeholder') @@ -358,14 +397,14 @@ class SelectOption: return payload -def _component_factory(data: ComponentPayload) -> Component: +def _component_factory(data: ComponentPayload, message: Message) -> Component: component_type = data['type'] if component_type == 1: - return ActionRow(data) + return ActionRow(data, message) elif component_type == 2: - return Button(data) # type: ignore + return Button(data, message) # type: ignore elif component_type == 3: - return SelectMenu(data) # type: ignore + return SelectMenu(data, message) # type: ignore else: as_enum = try_enum(ComponentType, component_type) return Component._raw_construct(type=as_enum) diff --git a/discord/http.py b/discord/http.py index 27b51e21e..3ac907ec1 100644 --- a/discord/http.py +++ b/discord/http.py @@ -1879,4 +1879,7 @@ class HTTPClient: 'reason': reason } - return self.request(Route('POST', '/report'), json=payload) \ No newline at end of file + return self.request(Route('POST', '/report'), json=payload) + + def interact(self, data) -> Response[None]: + return self.request(Route('POST', '/interactions'), json=data) \ No newline at end of file diff --git a/discord/message.py b/discord/message.py index 6b1dad811..635cd7104 100644 --- a/discord/message.py +++ b/discord/message.py @@ -75,7 +75,6 @@ if TYPE_CHECKING: from .mentions import AllowedMentions from .user import User from .role import Role - from .ui.view import View MR = TypeVar('MR', bound='MessageReference') EmojiInputType = Union[Emoji, PartialEmoji, str] @@ -604,6 +603,8 @@ class Message(Hashable): .. versionadded:: 2.0 guild: Optional[:class:`Guild`] The guild that the message belongs to, if applicable. + application_id: Optional[:class:`int`] + The application that sent this message, if applicable. """ __slots__ = ( @@ -619,6 +620,7 @@ class Message(Hashable): 'content', 'channel', 'webhook_id', + 'application_id', 'mention_everyone', 'embeds', 'id', @@ -659,6 +661,7 @@ class Message(Hashable): self._state: ConnectionState = state self.id: int = int(data['id']) self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id') + self.application_id: Optional[int] = utils._get_as_snowflake(data, 'application_id') self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])] self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']] self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] @@ -674,7 +677,7 @@ class Message(Hashable): self.content: str = data['content'] self.nonce: Optional[Union[int, str]] = data.get('nonce') self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] - self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])] + self.components: List[Component] = [_component_factory(d, self) for d in data.get('components', [])] self.call: Optional[CallMessage] = None try: @@ -1757,11 +1760,6 @@ class PartialMessage(Hashable): to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` are used instead. - view: Optional[:class:`~discord.ui.View`] - The updated view to update this message with. If ``None`` is passed then - the view is removed. - - .. versionadded:: 2.0 Raises ------- @@ -1818,17 +1816,6 @@ class PartialMessage(Hashable): allowed_mentions = allowed_mentions.to_dict() fields['allowed_mentions'] = allowed_mentions - try: - view = fields.pop('view') - except KeyError: - # To check for the view afterwards - view = None - else: - self._state.prevent_view_updates_for(self.id) - if view: - fields['components'] = view.to_components() - else: - fields['components'] = [] if fields: data = await self._state.http.edit_message(self.channel.id, self.id, **fields) @@ -1839,6 +1826,4 @@ class PartialMessage(Hashable): if fields: # data isn't unbound msg = self._state.create_message(channel=self.channel, data=data) # type: ignore - if view and not view.is_finished(): - self._state.store_view(view, self.id) return msg