Browse Source

Improve interaction object, implement message.interaction, implement Attachment.description

pull/10109/head
dolfies 4 years ago
parent
commit
397bca6b14
  1. 60
      discord/components.py
  2. 35
      discord/enums.py
  3. 55
      discord/message.py
  4. 2
      discord/reaction.py
  5. 22
      discord/state.py

60
discord/components.py

@ -28,7 +28,7 @@ from asyncio import TimeoutError
from datetime import datetime from datetime import datetime
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from .enums import try_enum, ComponentType, ButtonStyle from .enums import try_enum, ComponentType, ButtonStyle, InteractionType
from .errors import InvalidData from .errors import InvalidData
from .utils import get_slots, MISSING, time_snowflake from .utils import get_slots, MISSING, time_snowflake
from .partial_emoji import PartialEmoji, _EmojiTag from .partial_emoji import PartialEmoji, _EmojiTag
@ -41,8 +41,11 @@ if TYPE_CHECKING:
SelectOption as SelectOptionPayload, SelectOption as SelectOptionPayload,
ActionRow as ActionRowPayload, ActionRow as ActionRowPayload,
) )
from .types.snowflake import Snowflake
from .emoji import Emoji from .emoji import Emoji
from .message import Message from .message import Message
from .state import ConnectionState
from .user import BaseUser
__all__ = ( __all__ = (
@ -64,22 +67,61 @@ class Interaction:
id: :class:`int` id: :class:`int`
The interaction ID. The interaction ID.
nonce: Optional[Union[:class:`int`, :class:`str`]] nonce: Optional[Union[:class:`int`, :class:`str`]]
The interaction's nonce. The interaction's nonce. Not always present.
name: Optional[:class:`str`]
The name of the application command, if applicable.
type: :class:`InteractionType`
The type of interaction.
successful: Optional[:class:`bool`] successful: Optional[:class:`bool`]
Whether the interaction succeeded. Whether the interaction succeeded.
This is not immediately available, and is filled when Discord notifies us about the outcome of the interaction. If this is your interaction, this is not immediately available.
It is filled when Discord notifies us about the outcome of the interaction.
user: :class:`User`
The user who initiated the interaction.
""" """
__slots__ = ('id', 'nonce', 'successful') __slots__ = ('id', 'type', 'nonce', 'user', 'name', 'successful')
def __init__(self, *, id: int, nonce: Optional[Union[int, str]] = None) -> None: def __init__(
self,
id: int,
type: int,
nonce: Optional[Snowflake] = None,
*,
user: BaseUser,
name: Optional[str] = None,
) -> None:
self.id = id self.id = id
self.nonce = nonce self.nonce = nonce
self.type = try_enum(InteractionType, type)
self.user = user
self.name = name
self.successful: Optional[bool] = None self.successful: Optional[bool] = None
@classmethod
def _from_self(
cls, *, id: Snowflake, type: int, nonce: Optional[Snowflake] = None, user: BaseUser
) -> Interaction:
return cls(int(id), type, nonce, user=user)
@classmethod
def _from_message(
cls, state: ConnectionState, *, id: Snowflake, type: int, user: BaseUser, **data: Dict[str, Any]
) -> Interaction:
name = data.get('name')
user = state.store_user(user)
inst = cls(id, type, user=user, name=name)
inst.successful = True
return inst
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.successful s = self.successful
return f'<Interaction id={self.id}{f" successful={s}" if s is not None else ""}>' return f'<Interaction id={self.id} type={self.type}{f" successful={s}" if s is not None else ""} user={self.user!r}>'
def __bool__(self) -> bool:
if self.successful is not None:
return self.successful
raise TypeError('Interaction has not been resolved yet')
class Component: class Component:
@ -268,10 +310,11 @@ class Button(Component):
if message.guild: if message.guild:
payload['guild_id'] = str(message.guild.id) payload['guild_id'] = str(message.guild.id)
state._interactions[payload['nonce']] = 3
await state.http.interact(payload) await state.http.interact(payload)
try: try:
i = await state.client.wait_for( i = await state.client.wait_for(
'interaction', 'interaction_finish',
check=lambda d: d.nonce == payload['nonce'], check=lambda d: d.nonce == payload['nonce'],
timeout=5, timeout=5,
) )
@ -384,10 +427,11 @@ class SelectMenu(Component):
if message.guild: if message.guild:
payload['guild_id'] = str(message.guild.id) payload['guild_id'] = str(message.guild.id)
state._interactions[payload['nonce']] = 3
await state.http.interact(payload) await state.http.interact(payload)
try: try:
i = await state.client.wait_for( i = await state.client.wait_for(
'interaction', 'interaction_finish',
check=lambda d: d.nonce == payload['nonce'], check=lambda d: d.nonce == payload['nonce'],
timeout=5, timeout=5,
) )

35
discord/enums.py

@ -66,6 +66,8 @@ __all__ = (
'RequiredActionType', 'RequiredActionType',
'ReportType', 'ReportType',
'BrowserEnum', 'BrowserEnum',
'ApplicationCommandType',
'ApplicationCommandOptionType',
) )
@ -206,7 +208,10 @@ class MessageType(Enum):
call = 3 call = 3
channel_name_change = 4 channel_name_change = 4
channel_icon_change = 5 channel_icon_change = 5
channel_pinned_message = 6
pins_add = 6 pins_add = 6
member_join = 7
user_join = 7
new_member = 7 new_member = 7
premium_guild_subscription = 8 premium_guild_subscription = 8
premium_guild_tier_1 = 9 premium_guild_tier_1 = 9
@ -220,9 +225,10 @@ class MessageType(Enum):
guild_discovery_grace_period_final_warning = 17 guild_discovery_grace_period_final_warning = 17
thread_created = 18 thread_created = 18
reply = 19 reply = 19
application_command = 20 chat_input_command = 20
thread_starter_message = 21 thread_starter_message = 21
guild_invite_reminder = 22 guild_invite_reminder = 22
context_menu_command = 23
class VoiceRegion(Enum): class VoiceRegion(Enum):
@ -385,6 +391,7 @@ class NotificationLevel(Enum, comparable=True):
def __int__(self): def __int__(self):
return self.value return self.value
class AuditLogActionCategory(Enum): class AuditLogActionCategory(Enum):
create = 1 create = 1
delete = 2 delete = 2
@ -660,12 +667,36 @@ class InviteTarget(Enum):
embedded_application = 2 embedded_application = 2
class InteractionType(Enum): class InteractionType(Enum, comparable=True):
ping = 1 ping = 1
application_command = 2 application_command = 2
component = 3 component = 3
class ApplicationCommandType(Enum, comparable=True):
chat_input = 1
chat = 1
slash = 1
user = 2
message = 3
def __int__(self):
return self.value
class ApplicationCommandOptionType(Enum, comparable=True):
sub_command = 1
sub_command_group = 2
string = 3
integer = 4
boolean = 5
user = 6
channel = 7
role = 8
mentionable = 9
number = 10
class VideoQualityMode(Enum): class VideoQualityMode(Enum):
auto = 1 auto = 1
full = 2 full = 2

55
discord/message.py

@ -29,7 +29,7 @@ import datetime
import re import re
import io import io
from os import PathLike from os import PathLike
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, overload, TypeVar, Type
from . import utils from . import utils
from .reaction import Reaction from .reaction import Reaction
@ -38,7 +38,7 @@ from .partial_emoji import PartialEmoji
from .calls import CallMessage from .calls import CallMessage
from .enums import MessageType, ChannelType, try_enum from .enums import MessageType, ChannelType, try_enum
from .errors import InvalidArgument, HTTPException from .errors import InvalidArgument, HTTPException
from .components import _component_factory from .components import _component_factory, Interaction
from .embeds import Embed from .embeds import Embed
from .member import Member from .member import Member
from .flags import MessageFlags from .flags import MessageFlags
@ -71,7 +71,7 @@ if TYPE_CHECKING:
from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel
from .components import Component from .components import Component
from .state import ConnectionState from .state import ConnectionState
from .channel import TextChannel, GroupChannel, DMChannel, PartialMessageable from .channel import TextChannel
from .mentions import AllowedMentions from .mentions import AllowedMentions
from .user import User from .user import User
from .role import Role from .role import Role
@ -97,8 +97,8 @@ def convert_emoji_reaction(emoji):
if isinstance(emoji, PartialEmoji): if isinstance(emoji, PartialEmoji):
return emoji._as_reaction() return emoji._as_reaction()
if isinstance(emoji, str): if isinstance(emoji, str):
# Reactions can be in :name:id format, but not <:name:id>. # Reactions can be in :name:id format, but not <:name:id>
# No existing emojis have <> in them, so this should be okay. # Emojis can't have <> in them, so this should be okay
return emoji.strip('<>') return emoji.strip('<>')
raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.') raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.')
@ -151,9 +151,13 @@ class Attachment(Hashable):
The attachment's `media type <https://en.wikipedia.org/wiki/Media_type>`_ The attachment's `media type <https://en.wikipedia.org/wiki/Media_type>`_
.. versionadded:: 1.7 .. versionadded:: 1.7
description: Optional[:class:`str`]
The attachment's description (alt text).
.. versionadded:: 2.0
""" """
__slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type') __slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type', 'description')
def __init__(self, *, data: AttachmentPayload, state: ConnectionState): def __init__(self, *, data: AttachmentPayload, state: ConnectionState):
self.id: int = int(data['id']) self.id: int = int(data['id'])
@ -165,6 +169,7 @@ class Attachment(Hashable):
self.proxy_url: str = data.get('proxy_url') self.proxy_url: str = data.get('proxy_url')
self._http = state.http self._http = state.http
self.content_type: Optional[str] = data.get('content_type') self.content_type: Optional[str] = data.get('content_type')
self.description: Optional[str] = data.get('description')
def is_spoiler(self) -> bool: def is_spoiler(self) -> bool:
""":class:`bool`: Whether this attachment contains a spoiler.""" """:class:`bool`: Whether this attachment contains a spoiler."""
@ -342,7 +347,7 @@ class DeletedReferencedMessage:
@property @property
def id(self) -> int: def id(self) -> int:
""":class:`int`: The message ID of the deleted referenced message.""" """:class:`int`: The message ID of the deleted referenced message."""
# the parent's message id won't be None here # The parent's message id won't be None here
return self._parent.message_id # type: ignore return self._parent.message_id # type: ignore
@property @property
@ -459,7 +464,7 @@ class MessageReference:
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>' return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>'
def to_dict(self) -> MessageReferencePayload: def to_dict(self) -> MessageReferencePayload:
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} # type: ignore
result['channel_id'] = self.channel_id result['channel_id'] = self.channel_id
if self.guild_id is not None: if self.guild_id is not None:
result['guild_id'] = self.guild_id result['guild_id'] = self.guild_id
@ -478,7 +483,7 @@ def flatten_handlers(cls):
if key.startswith('_handle_') and key != '_handle_member' if key.startswith('_handle_') and key != '_handle_member'
] ]
# store _handle_member last # Store _handle_member last
handlers.append(('member', cls._handle_member)) handlers.append(('member', cls._handle_member))
cls._HANDLERS = handlers cls._HANDLERS = handlers
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')] cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')]
@ -605,6 +610,12 @@ class Message(Hashable):
The guild that the message belongs to, if applicable. The guild that the message belongs to, if applicable.
application_id: Optional[:class:`int`] application_id: Optional[:class:`int`]
The application that sent this message, if applicable. The application that sent this message, if applicable.
.. versionadded:: 2.0
interaction: Optional[:class:`Interaction`]
The interaction the message is replying to, if applicable.
.. versionadded:: 2.0
""" """
__slots__ = ( __slots__ = (
@ -640,6 +651,7 @@ class Message(Hashable):
'components', 'components',
'guild', 'guild',
'call', 'call',
'interaction',
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -681,7 +693,7 @@ class Message(Hashable):
self.call: Optional[CallMessage] = None self.call: Optional[CallMessage] = None
try: try:
# if the channel doesn't have a guild attribute, we handle that # If the channel doesn't have a guild attribute, we handle that
self.guild = channel.guild # type: ignore self.guild = channel.guild # type: ignore
except AttributeError: except AttributeError:
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id')) self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id'))
@ -700,16 +712,16 @@ class Message(Hashable):
if resolved is None: if resolved is None:
ref.resolved = DeletedReferencedMessage(ref) ref.resolved = DeletedReferencedMessage(ref)
else: else:
# Right now the channel IDs match but maybe in the future they won't. # Right now the channel IDs match but maybe in the future they won't
if ref.channel_id == channel.id: if ref.channel_id == channel.id:
chan = channel chan = channel
else: else:
chan, _ = state._get_guild_channel(resolved) chan, _ = state._get_guild_channel(resolved)
# the channel will be the correct type here # The channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles', 'call'): for handler in ('author', 'member', 'mentions', 'mention_roles', 'call', 'interaction'):
try: try:
getattr(self, f'_handle_{handler}')(data[handler]) getattr(self, f'_handle_{handler}')(data[handler])
except KeyError: except KeyError:
@ -900,6 +912,9 @@ class Message(Hashable):
def _handle_components(self, components: List[ComponentPayload]): def _handle_components(self, components: List[ComponentPayload]):
self.components = [_component_factory(d, self) for d in components] self.components = [_component_factory(d, self) for d in components]
def _handle_interaction(self, interaction: Dict[str, Any]):
self.interaction = Interaction._from_message(self._state, **interaction)
def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None: def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None:
self.guild = new_guild self.guild = new_guild
self.channel = new_channel self.channel = new_channel
@ -1015,7 +1030,8 @@ class Message(Hashable):
return self.type not in ( return self.type not in (
MessageType.default, MessageType.default,
MessageType.reply, MessageType.reply,
MessageType.application_command, MessageType.chat_input_command,
MessageType.context_menu_command,
MessageType.thread_starter_message, MessageType.thread_starter_message,
) )
@ -1029,7 +1045,12 @@ class Message(Hashable):
returns an English message denoting the contents of the system message. returns an English message denoting the contents of the system message.
""" """
if self.type in (MessageType.default, MessageType.reply): if self.type in {
MessageType.default,
MessageType.reply,
MessageType.chat_input_command,
MessageType.context_menu_command,
}:
return self.content return self.content
if self.type is MessageType.recipient_add: if self.type is MessageType.recipient_add:
@ -1074,9 +1095,9 @@ class Message(Hashable):
return formats[created_at_ms % len(formats)].format(self.author.name) return formats[created_at_ms % len(formats)].format(self.author.name)
if self.type is MessageType.call: if self.type is MessageType.call:
call_ended = self.call.ended_timestamp is not None call_ended = self.call.ended_timestamp is not None # type: ignore
if self.channel.me in self.call.participants: if self.channel.me in self.call.participants: # type: ignore
return f'{self.author.name} started a call.' return f'{self.author.name} started a call.'
elif call_ended: elif call_ended:
return f'You missed a call from {self.author.name}' return f'You missed a call from {self.author.name}'

2
discord/reaction.py

@ -166,7 +166,7 @@ class Reaction:
Usage :: Usage ::
# I do not actually recommend doing this. # I do not actually recommend doing this
async for user in reaction.users(): async for user in reaction.users():
await channel.send(f'{user} has reacted with {reaction.emoji}!') await channel.send(f'{user} has reacted with {reaction.emoji}!')

22
discord/state.py

@ -29,7 +29,7 @@ from collections import deque
import copy import copy
import datetime import datetime
import logging import logging
from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Tuple, Deque
import inspect import inspect
import time import time
import os import os
@ -52,7 +52,6 @@ from .role import Role
from .enums import ChannelType, RequiredActionType, Status, try_enum, UnavailableGuildType, VoiceRegion from .enums import ChannelType, RequiredActionType, Status, try_enum, UnavailableGuildType, VoiceRegion
from . import utils from . import utils
from .flags import GuildSubscriptionOptions, MemberCacheFlags from .flags import GuildSubscriptionOptions, MemberCacheFlags
from .object import Object
from .invite import Invite from .invite import Invite
from .integrations import _integration_factory from .integrations import _integration_factory
from .stage_instance import StageInstance from .stage_instance import StageInstance
@ -264,7 +263,7 @@ class ConnectionState:
self._voice_clients: Dict[int, VoiceProtocol] = {} self._voice_clients: Dict[int, VoiceProtocol] = {}
self._voice_states: Dict[int, VoiceState] = {} self._voice_states: Dict[int, VoiceState] = {}
self._interactions: Dict[int, Interaction] = {} self._interactions: Dict[Union[int, str], Union[int, Interaction]] = {}
self._relationships: Dict[int, Relationship] = {} self._relationships: Dict[int, Relationship] = {}
self._private_channels: Dict[int, PrivateChannel] = {} self._private_channels: Dict[int, PrivateChannel] = {}
self._private_channels_by_user: Dict[int, DMChannel] = {} self._private_channels_by_user: Dict[int, DMChannel] = {}
@ -631,8 +630,8 @@ class ConnectionState:
# Parsing that would require a redesign of the Relationship class ;-; # Parsing that would require a redesign of the Relationship class ;-;
# Self parsing # Self parsing
self.user = ClientUser(state=self, data=data['user']) self.user = user = ClientUser(state=self, data=data['user'])
user = self.store_user(data['user']) self.store_user(data['user'])
# Temp user parsing # Temp user parsing
temp_users = {user.id: user._to_minimal_user_json()} temp_users = {user.id: user._to_minimal_user_json()}
@ -670,7 +669,7 @@ class ConnectionState:
self.analytics_token = data.get('analytics_token') self.analytics_token = data.get('analytics_token')
region = data.get('geo_ordered_rtc_regions', ['us-west'])[0] region = data.get('geo_ordered_rtc_regions', ['us-west'])[0]
self.preferred_region = try_enum(VoiceRegion, region) self.preferred_region = try_enum(VoiceRegion, region)
self.settings = settings = UserSettings(data=data.get('user_settings', {}), state=self) self.settings = UserSettings(data=data.get('user_settings', {}), state=self)
self.consents = Tracking(data.get('consents', {})) self.consents = Tracking(data.get('consents', {}))
# We're done # We're done
@ -901,7 +900,7 @@ class ConnectionState:
_log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) _log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
def parse_channel_create(self, data) -> None: def parse_channel_create(self, data) -> None:
factory, ch_type = _channel_factory(data['type']) factory, _ = _channel_factory(data['type'])
if factory is None: if factory is None:
_log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) _log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type'])
return return
@ -1609,10 +1608,6 @@ class ConnectionState:
coro = vc.on_voice_server_update(data) coro = vc.on_voice_server_update(data)
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
def parse_user_required_action_update(self, data) -> None:
required_action = try_enum(RequiredActionType, data['required_action'])
self.dispatch('required_action_update', required_action)
def parse_typing_start(self, data) -> None: def parse_typing_start(self, data) -> None:
channel, guild = self._get_guild_channel(data) channel, guild = self._get_guild_channel(data)
if channel is not None: if channel is not None:
@ -1661,9 +1656,10 @@ class ConnectionState:
self.dispatch('relationship_remove', old) self.dispatch('relationship_remove', old)
def parse_interaction_create(self, data) -> None: def parse_interaction_create(self, data) -> None:
i = Interaction(**data) type = self._interactions.pop(data['nonce'], 0)
i = Interaction._from_self(type=type, user=self.user, **data)
self._interactions[i.id] = i self._interactions[i.id] = i
self.dispatch('interaction', i) self.dispatch('interaction_create', i)
def parse_interaction_success(self, data) -> None: def parse_interaction_success(self, data) -> None:
id = int(data['id']) id = int(data['id'])

Loading…
Cancel
Save