From a86d42ce55e571cf7c2d4937839b5d5b5a77aaa3 Mon Sep 17 00:00:00 2001 From: dolfies Date: Mon, 24 Jan 2022 20:15:56 -0500 Subject: [PATCH] Fix interactions --- discord/commands.py | 24 ++++++++++++++++++------ discord/components.py | 10 ++++++---- discord/interactions.py | 2 +- discord/state.py | 4 ++-- discord/utils.py | 10 ++++++++++ 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/discord/commands.py b/discord/commands.py index 34ca6d30c..924892760 100644 --- a/discord/commands.py +++ b/discord/commands.py @@ -29,7 +29,7 @@ from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable from .enums import CommandType, ChannelType, OptionType, try_enum from .errors import InvalidData, InvalidArgument -from .utils import time_snowflake +from .utils import _generate_session_id, time_snowflake if TYPE_CHECKING: from .abc import Messageable, Snowflake @@ -99,6 +99,7 @@ class ApplicationCommand(Protocol): 'channel_id': str(channel.id), 'data': data, 'nonce': str(time_snowflake(datetime.utcnow())), + 'session_id': state.session_id or _generate_session_id(), 'type': 2, # Should be an enum but eh } if getattr(channel, 'guild', None) is not None: @@ -143,19 +144,21 @@ class BaseCommand(ApplicationCommand): 'version', 'type', 'default_permission', - '_dm_permission', - '_default_member_permissions', + '_data', '_state', '_channel', - '_application_id' + '_application_id', + '_dm_permission', + '_default_member_permissions', ) def __init__( self, *, state: ConnectionState, data: Dict[str, Any], channel: Optional[Messageable] = None ) -> None: + self._state = state + self._data = data self.name = data['name'] self.description = data['description'] - self._state = state self._channel = channel self._application_id: int = int(data['application_id']) self.id: int = int(data['id']) @@ -202,7 +205,7 @@ class BaseCommand(ApplicationCommand): self._channel = value -class SlashMixin(ApplicationCommand): +class SlashMixin(ApplicationCommand, Protocol): if TYPE_CHECKING: _parent: SlashCommand options: List[Option] @@ -210,7 +213,10 @@ class SlashMixin(ApplicationCommand): async def __call__(self, options, channel=None): obj = self._parent + command = obj._data + command['name_localized'] = command['name'] data = { + 'application_command': command, 'attachments': [], 'id': str(obj.id), 'name': obj.name, @@ -299,7 +305,10 @@ class UserCommand(BaseCommand): if user is None: raise TypeError('__call__() missing 1 required positional argument: \'user\'') + command = self._data + command['name_localized'] = command['name'] data = { + 'application_command': command, 'attachments': [], 'id': str(self.id), 'name': self.name, @@ -367,7 +376,10 @@ class MessageCommand(BaseCommand): if message is None: raise TypeError('__call__() missing 1 required positional argument: \'message\'') + command = self._data + command['name_localized'] = command['name'] data = { + 'application_command': command, 'attachments': [], 'id': str(self.id), 'name': self.name, diff --git a/discord/components.py b/discord/components.py index 874a59102..3ea527134 100644 --- a/discord/components.py +++ b/discord/components.py @@ -30,7 +30,7 @@ from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Ty from .enums import try_enum, ComponentType, ButtonStyle from .errors import InvalidData -from .utils import get_slots, MISSING, time_snowflake +from .utils import _generate_session_id, get_slots, MISSING, time_snowflake from .partial_emoji import PartialEmoji, _EmojiTag if TYPE_CHECKING: @@ -228,7 +228,7 @@ class Button(Component): message = self.message state = message._state payload = { - 'application_id': str(message.application_id), + 'application_id': str(message.application_id or message.author.id), 'channel_id': str(message.channel.id), 'data': { 'component_type': 2, @@ -237,12 +237,13 @@ class Button(Component): 'message_flags': message.flags.value, 'message_id': str(message.id), 'nonce': str(time_snowflake(datetime.utcnow())), + 'session_id': state.session_id or _generate_session_id(), 'type': 3, # Should be an enum but eh } if message.guild: payload['guild_id'] = str(message.guild.id) - state._interactions[payload['nonce']] = 3 + state._interactions[payload['nonce']] = (3, None) await state.http.interact(payload) try: i = await state.client.wait_for( @@ -354,12 +355,13 @@ class SelectMenu(Component): 'message_flags': message.flags.value, 'message_id': str(message.id), 'nonce': str(time_snowflake(datetime.utcnow())), + 'session_id': state.session_id or _generate_session_id(), 'type': 3, # Should be an enum but eh } if message.guild: payload['guild_id'] = str(message.guild.id) - state._interactions[payload['nonce']] = 3 + state._interactions[payload['nonce']] = (3, None) await state.http.interact(payload) try: i = await state.client.wait_for( diff --git a/discord/interactions.py b/discord/interactions.py index 98224cd9d..38b248264 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -76,7 +76,7 @@ class Interaction: @classmethod def _from_self( - cls, *, id: Snowflake, type: int, nonce: Optional[Snowflake] = None, user: ClientUser, name: str + cls, *, id: Snowflake, type: int, nonce: Optional[Snowflake] = None, user: ClientUser, name: Optional[str] ) -> Interaction: return cls(int(id), type, nonce, user=user, name=name) diff --git a/discord/state.py b/discord/state.py index 35fd76d45..c7c2ac88d 100644 --- a/discord/state.py +++ b/discord/state.py @@ -262,7 +262,7 @@ class ConnectionState: self._voice_clients: Dict[int, VoiceProtocol] = {} self._voice_states: Dict[int, VoiceState] = {} - self._interactions: Dict[Union[int, str], Union[Tuple[int, str], Interaction]] = {} + self._interactions: Dict[Union[int, str], Union[Tuple[int, Optional[str]], Interaction]] = {} self._relationships: Dict[int, Relationship] = {} self._private_channels: Dict[int, PrivateChannel] = {} self._private_channels_by_user: Dict[int, DMChannel] = {} @@ -1722,7 +1722,7 @@ class ConnectionState: self.dispatch('relationship_remove', old) def parse_interaction_create(self, data) -> None: - type, name = self._interactions.pop(data['nonce'], (0, '')) + type, name = self._interactions.pop(data['nonce'], (0, None)) i = Interaction._from_self(type=type, user=self.user, name=name, **data) # type: ignore self._interactions[i.id] = i self.dispatch('interaction_create', i) diff --git a/discord/utils.py b/discord/utils.py index 6bb9965ec..65dcd50db 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -59,7 +59,9 @@ import json import logging import os import platform +import random import re +import string import subprocess import sys import tempfile @@ -1067,6 +1069,14 @@ def set_target( except AttributeError: pass + +def _generate_session_id() -> str: + return ''.join( + random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) + for _ in range(16) + ) + + class ExpiringQueue(asyncio.Queue): # Inspired from https://github.com/NoahCardoza/CaptchaHarvester def __init__(self, timeout: int, maxsize: int = 0) -> None: super().__init__(maxsize)