diff --git a/discord/components.py b/discord/components.py index fe2093978..aa67b87ca 100644 --- a/discord/components.py +++ b/discord/components.py @@ -28,7 +28,7 @@ from asyncio import TimeoutError from datetime import datetime from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union -from .enums import try_enum, ComponentType, ButtonStyle +from .enums import try_enum, ComponentType, ButtonStyle, InteractionType from .errors import InvalidData from .utils import get_slots, MISSING, time_snowflake from .partial_emoji import PartialEmoji, _EmojiTag @@ -41,8 +41,11 @@ if TYPE_CHECKING: SelectOption as SelectOptionPayload, ActionRow as ActionRowPayload, ) + from .types.snowflake import Snowflake from .emoji import Emoji from .message import Message + from .state import ConnectionState + from .user import BaseUser __all__ = ( @@ -64,22 +67,61 @@ class Interaction: id: :class:`int` The interaction ID. nonce: Optional[Union[:class:`int`, :class:`str`]] - The interaction's nonce. + The interaction's nonce. Not always present. + name: Optional[:class:`str`] + The name of the application command, if applicable. + type: :class:`InteractionType` + The type of interaction. successful: Optional[:class:`bool`] Whether the interaction succeeded. - This is not immediately available, and is filled when Discord notifies us about the outcome of the interaction. + If this is your interaction, this is not immediately available. + It is filled when Discord notifies us about the outcome of the interaction. + user: :class:`User` + The user who initiated the interaction. """ - __slots__ = ('id', 'nonce', 'successful') + __slots__ = ('id', 'type', 'nonce', 'user', 'name', 'successful') - def __init__(self, *, id: int, nonce: Optional[Union[int, str]] = None) -> None: + def __init__( + self, + id: int, + type: int, + nonce: Optional[Snowflake] = None, + *, + user: BaseUser, + name: Optional[str] = None, + ) -> None: self.id = id self.nonce = nonce + self.type = try_enum(InteractionType, type) + self.user = user + self.name = name self.successful: Optional[bool] = None + @classmethod + def _from_self( + cls, *, id: Snowflake, type: int, nonce: Optional[Snowflake] = None, user: BaseUser + ) -> Interaction: + return cls(int(id), type, nonce, user=user) + + @classmethod + def _from_message( + cls, state: ConnectionState, *, id: Snowflake, type: int, user: BaseUser, **data: Dict[str, Any] + ) -> Interaction: + name = data.get('name') + user = state.store_user(user) + inst = cls(id, type, user=user, name=name) + inst.successful = True + return inst + def __repr__(self) -> str: s = self.successful - return f'' + return f'' + + def __bool__(self) -> bool: + if self.successful is not None: + return self.successful + raise TypeError('Interaction has not been resolved yet') class Component: @@ -268,10 +310,11 @@ class Button(Component): if message.guild: payload['guild_id'] = str(message.guild.id) + state._interactions[payload['nonce']] = 3 await state.http.interact(payload) try: i = await state.client.wait_for( - 'interaction', + 'interaction_finish', check=lambda d: d.nonce == payload['nonce'], timeout=5, ) @@ -384,10 +427,11 @@ class SelectMenu(Component): if message.guild: payload['guild_id'] = str(message.guild.id) + state._interactions[payload['nonce']] = 3 await state.http.interact(payload) try: i = await state.client.wait_for( - 'interaction', + 'interaction_finish', check=lambda d: d.nonce == payload['nonce'], timeout=5, ) diff --git a/discord/enums.py b/discord/enums.py index 498f5adfe..a9f430f8f 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -66,6 +66,8 @@ __all__ = ( 'RequiredActionType', 'ReportType', 'BrowserEnum', + 'ApplicationCommandType', + 'ApplicationCommandOptionType', ) @@ -206,7 +208,10 @@ class MessageType(Enum): call = 3 channel_name_change = 4 channel_icon_change = 5 + channel_pinned_message = 6 pins_add = 6 + member_join = 7 + user_join = 7 new_member = 7 premium_guild_subscription = 8 premium_guild_tier_1 = 9 @@ -220,9 +225,10 @@ class MessageType(Enum): guild_discovery_grace_period_final_warning = 17 thread_created = 18 reply = 19 - application_command = 20 + chat_input_command = 20 thread_starter_message = 21 guild_invite_reminder = 22 + context_menu_command = 23 class VoiceRegion(Enum): @@ -385,6 +391,7 @@ class NotificationLevel(Enum, comparable=True): def __int__(self): return self.value + class AuditLogActionCategory(Enum): create = 1 delete = 2 @@ -660,12 +667,36 @@ class InviteTarget(Enum): embedded_application = 2 -class InteractionType(Enum): +class InteractionType(Enum, comparable=True): ping = 1 application_command = 2 component = 3 +class ApplicationCommandType(Enum, comparable=True): + chat_input = 1 + chat = 1 + slash = 1 + user = 2 + message = 3 + + def __int__(self): + return self.value + + +class ApplicationCommandOptionType(Enum, comparable=True): + sub_command = 1 + sub_command_group = 2 + string = 3 + integer = 4 + boolean = 5 + user = 6 + channel = 7 + role = 8 + mentionable = 9 + number = 10 + + class VideoQualityMode(Enum): auto = 1 full = 2 diff --git a/discord/message.py b/discord/message.py index a7ceceb78..9eeef2248 100644 --- a/discord/message.py +++ b/discord/message.py @@ -29,7 +29,7 @@ import datetime import re import io from os import PathLike -from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type +from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, overload, TypeVar, Type from . import utils from .reaction import Reaction @@ -38,7 +38,7 @@ from .partial_emoji import PartialEmoji from .calls import CallMessage from .enums import MessageType, ChannelType, try_enum from .errors import InvalidArgument, HTTPException -from .components import _component_factory +from .components import _component_factory, Interaction from .embeds import Embed from .member import Member from .flags import MessageFlags @@ -71,7 +71,7 @@ if TYPE_CHECKING: from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel from .components import Component from .state import ConnectionState - from .channel import TextChannel, GroupChannel, DMChannel, PartialMessageable + from .channel import TextChannel from .mentions import AllowedMentions from .user import User from .role import Role @@ -97,8 +97,8 @@ def convert_emoji_reaction(emoji): if isinstance(emoji, PartialEmoji): return emoji._as_reaction() if isinstance(emoji, str): - # Reactions can be in :name:id format, but not <:name:id>. - # No existing emojis have <> in them, so this should be okay. + # Reactions can be in :name:id format, but not <:name:id> + # Emojis can't have <> in them, so this should be okay return emoji.strip('<>') raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.') @@ -151,9 +151,13 @@ class Attachment(Hashable): The attachment's `media type `_ .. versionadded:: 1.7 + description: Optional[:class:`str`] + The attachment's description (alt text). + + .. versionadded:: 2.0 """ - __slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type') + __slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type', 'description') def __init__(self, *, data: AttachmentPayload, state: ConnectionState): self.id: int = int(data['id']) @@ -165,6 +169,7 @@ class Attachment(Hashable): self.proxy_url: str = data.get('proxy_url') self._http = state.http self.content_type: Optional[str] = data.get('content_type') + self.description: Optional[str] = data.get('description') def is_spoiler(self) -> bool: """:class:`bool`: Whether this attachment contains a spoiler.""" @@ -342,7 +347,7 @@ class DeletedReferencedMessage: @property def id(self) -> int: """:class:`int`: The message ID of the deleted referenced message.""" - # the parent's message id won't be None here + # The parent's message id won't be None here return self._parent.message_id # type: ignore @property @@ -459,7 +464,7 @@ class MessageReference: return f'' def to_dict(self) -> MessageReferencePayload: - result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} + result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} # type: ignore result['channel_id'] = self.channel_id if self.guild_id is not None: result['guild_id'] = self.guild_id @@ -478,7 +483,7 @@ def flatten_handlers(cls): if key.startswith('_handle_') and key != '_handle_member' ] - # store _handle_member last + # Store _handle_member last handlers.append(('member', cls._handle_member)) cls._HANDLERS = handlers cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')] @@ -605,6 +610,12 @@ class Message(Hashable): The guild that the message belongs to, if applicable. application_id: Optional[:class:`int`] The application that sent this message, if applicable. + + .. versionadded:: 2.0 + interaction: Optional[:class:`Interaction`] + The interaction the message is replying to, if applicable. + + .. versionadded:: 2.0 """ __slots__ = ( @@ -640,6 +651,7 @@ class Message(Hashable): 'components', 'guild', 'call', + 'interaction', ) if TYPE_CHECKING: @@ -681,7 +693,7 @@ class Message(Hashable): self.call: Optional[CallMessage] = None try: - # if the channel doesn't have a guild attribute, we handle that + # If the channel doesn't have a guild attribute, we handle that self.guild = channel.guild # type: ignore except AttributeError: self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id')) @@ -700,16 +712,16 @@ class Message(Hashable): if resolved is None: ref.resolved = DeletedReferencedMessage(ref) else: - # Right now the channel IDs match but maybe in the future they won't. + # Right now the channel IDs match but maybe in the future they won't if ref.channel_id == channel.id: chan = channel else: chan, _ = state._get_guild_channel(resolved) - # the channel will be the correct type here + # The channel will be the correct type here ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore - for handler in ('author', 'member', 'mentions', 'mention_roles', 'call'): + for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction'): try: getattr(self, f'_handle_{handler}')(data[handler]) except KeyError: @@ -900,6 +912,9 @@ class Message(Hashable): def _handle_components(self, components: List[ComponentPayload]): self.components = [_component_factory(d, self) for d in components] + def _handle_interaction(self, interaction: Dict[str, Any]): + self.interaction = Interaction._from_message(self._state, **interaction) + def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None: self.guild = new_guild self.channel = new_channel @@ -1015,7 +1030,8 @@ class Message(Hashable): return self.type not in ( MessageType.default, MessageType.reply, - MessageType.application_command, + MessageType.chat_input_command, + MessageType.context_menu_command, MessageType.thread_starter_message, ) @@ -1029,7 +1045,12 @@ class Message(Hashable): returns an English message denoting the contents of the system message. """ - if self.type in (MessageType.default, MessageType.reply): + if self.type in { + MessageType.default, + MessageType.reply, + MessageType.chat_input_command, + MessageType.context_menu_command, + }: return self.content if self.type is MessageType.recipient_add: @@ -1074,9 +1095,9 @@ class Message(Hashable): return formats[created_at_ms % len(formats)].format(self.author.name) if self.type is MessageType.call: - call_ended = self.call.ended_timestamp is not None + call_ended = self.call.ended_timestamp is not None # type: ignore - if self.channel.me in self.call.participants: + if self.channel.me in self.call.participants: # type: ignore return f'{self.author.name} started a call.' elif call_ended: return f'You missed a call from {self.author.name}' diff --git a/discord/reaction.py b/discord/reaction.py index 04eee3427..882d49321 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -166,7 +166,7 @@ class Reaction: Usage :: - # I do not actually recommend doing this. + # I do not actually recommend doing this async for user in reaction.users(): await channel.send(f'{user} has reacted with {reaction.emoji}!') diff --git a/discord/state.py b/discord/state.py index d3f325051..6e2b04abf 100644 --- a/discord/state.py +++ b/discord/state.py @@ -29,7 +29,7 @@ from collections import deque import copy import datetime import logging -from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque +from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Tuple, Deque import inspect import time import os @@ -52,7 +52,6 @@ from .role import Role from .enums import ChannelType, RequiredActionType, Status, try_enum, UnavailableGuildType, VoiceRegion from . import utils from .flags import GuildSubscriptionOptions, MemberCacheFlags -from .object import Object from .invite import Invite from .integrations import _integration_factory from .stage_instance import StageInstance @@ -264,7 +263,7 @@ class ConnectionState: self._voice_clients: Dict[int, VoiceProtocol] = {} self._voice_states: Dict[int, VoiceState] = {} - self._interactions: Dict[int, Interaction] = {} + self._interactions: Dict[Union[int, str], Union[int, Interaction]] = {} self._relationships: Dict[int, Relationship] = {} self._private_channels: Dict[int, PrivateChannel] = {} self._private_channels_by_user: Dict[int, DMChannel] = {} @@ -631,8 +630,8 @@ class ConnectionState: # Parsing that would require a redesign of the Relationship class ;-; # Self parsing - self.user = ClientUser(state=self, data=data['user']) - user = self.store_user(data['user']) + self.user = user = ClientUser(state=self, data=data['user']) + self.store_user(data['user']) # Temp user parsing temp_users = {user.id: user._to_minimal_user_json()} @@ -670,7 +669,7 @@ class ConnectionState: self.analytics_token = data.get('analytics_token') region = data.get('geo_ordered_rtc_regions', ['us-west'])[0] self.preferred_region = try_enum(VoiceRegion, region) - self.settings = settings = UserSettings(data=data.get('user_settings', {}), state=self) + self.settings = UserSettings(data=data.get('user_settings', {}), state=self) self.consents = Tracking(data.get('consents', {})) # We're done @@ -901,7 +900,7 @@ class ConnectionState: _log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) def parse_channel_create(self, data) -> None: - factory, ch_type = _channel_factory(data['type']) + factory, _ = _channel_factory(data['type']) if factory is None: _log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) return @@ -1609,10 +1608,6 @@ class ConnectionState: coro = vc.on_voice_server_update(data) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) - def parse_user_required_action_update(self, data) -> None: - required_action = try_enum(RequiredActionType, data['required_action']) - self.dispatch('required_action_update', required_action) - def parse_typing_start(self, data) -> None: channel, guild = self._get_guild_channel(data) if channel is not None: @@ -1661,9 +1656,10 @@ class ConnectionState: self.dispatch('relationship_remove', old) def parse_interaction_create(self, data) -> None: - i = Interaction(**data) + type = self._interactions.pop(data['nonce'], 0) + i = Interaction._from_self(type=type, user=self.user, **data) self._interactions[i.id] = i - self.dispatch('interaction', i) + self.dispatch('interaction_create', i) def parse_interaction_success(self, data) -> None: id = int(data['id'])