From 92d1b4cd2b8cda4bb7f9c39acbf868ee66afaec0 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Fri, 18 Feb 2022 08:03:36 -0500 Subject: [PATCH] Refactor interaction response handling to support files This adds support for file sending and allowed_mentions --- discord/interactions.py | 112 +++++++++++++++----------------- discord/webhook/async_.py | 130 ++++++++++++++++++++++++++++++++++---- 2 files changed, 170 insertions(+), 72 deletions(-) diff --git a/discord/interactions.py b/discord/interactions.py index 35d6c5982..57925f2b3 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -31,6 +31,7 @@ import asyncio from . import utils from .enums import try_enum, InteractionType, InteractionResponseType from .errors import InteractionResponded, HTTPException, ClientException +from .flags import MessageFlags from .channel import PartialMessageable, ChannelType from .user import User @@ -39,7 +40,7 @@ from .message import Message, Attachment from .object import Object from .permissions import Permissions from .http import handle_message_parameters -from .webhook.async_ import async_context, Webhook +from .webhook.async_ import async_context, Webhook, interaction_response_params, interaction_message_response_params __all__ = ( 'Interaction', @@ -421,9 +422,8 @@ class InteractionResponse: if defer_type: adapter = async_context.get() - await adapter.create_interaction_response( - parent.id, parent.token, session=parent._session, type=defer_type, data=data - ) + params = interaction_response_params(type=defer_type, data=data) + await adapter.create_interaction_response(parent.id, parent.token, session=parent._session, params=params) self._responded = True async def pong(self) -> None: @@ -446,9 +446,8 @@ class InteractionResponse: parent = self._parent if parent.type is InteractionType.ping: adapter = async_context.get() - await adapter.create_interaction_response( - parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value - ) + params = interaction_response_params(InteractionResponseType.pong.value) + await adapter.create_interaction_response(parent.id, parent.token, session=parent._session, params=params) self._responded = True async def send_message( @@ -457,9 +456,12 @@ class InteractionResponse: *, embed: Embed = MISSING, embeds: List[Embed] = MISSING, + file: File = MISSING, + files: List[File] = MISSING, view: View = MISSING, tts: bool = False, ephemeral: bool = False, + allowed_mentions: AllowedMentions = MISSING, ) -> None: """|coro| @@ -475,6 +477,10 @@ class InteractionResponse: embed: :class:`Embed` The rich embed for the content to send. This cannot be mixed with ``embeds`` parameter. + file: :class:`~discord.File` + The file to upload. + files: List[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. tts: :class:`bool` Indicates if the message should be sent using text-to-speech. view: :class:`discord.ui.View` @@ -483,13 +489,16 @@ class InteractionResponse: Indicates if the message should only be visible to the user who started the interaction. If a view is sent with an ephemeral message and it has no timeout set then the timeout is set to 15 minutes. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for + more information. Raises ------- HTTPException Sending the message failed. TypeError - You specified both ``embed`` and ``embeds``. + You specified both ``embed`` and ``embeds`` or ``file`` and ``files``. ValueError The length of ``embeds`` was invalid. InteractionResponded @@ -498,38 +507,32 @@ class InteractionResponse: if self._responded: raise InteractionResponded(self._parent) - payload: Dict[str, Any] = { - 'tts': tts, - } - - if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix embed and embeds keyword arguments') - - if embed is not MISSING: - embeds = [embed] - - if embeds: - if len(embeds) > 10: - raise ValueError('embeds cannot exceed maximum of 10 elements') - payload['embeds'] = [e.to_dict() for e in embeds] - - if content is not None: - payload['content'] = str(content) - if ephemeral: - payload['flags'] = 64 - - if view is not MISSING: - payload['components'] = view.to_components() + flags = MessageFlags._from_value(64) + else: + flags = MISSING parent = self._parent adapter = async_context.get() + params = interaction_message_response_params( + type=InteractionResponseType.channel_message.value, + content=content, + tts=tts, + embeds=embeds, + embed=embed, + file=file, + files=files, + previous_allowed_mentions=parent._state.allowed_mentions, + allowed_mentions=allowed_mentions, + flags=flags, + view=view, + ) + await adapter.create_interaction_response( parent.id, parent.token, session=parent._session, - type=InteractionResponseType.channel_message.value, - data=payload, + params=params, ) if view is not MISSING: @@ -548,6 +551,7 @@ class InteractionResponse: embeds: List[Embed] = MISSING, attachments: List[Attachment] = MISSING, view: Optional[View] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING, ) -> None: """|coro| @@ -569,6 +573,9 @@ class InteractionResponse: view: Optional[:class:`~discord.ui.View`] The updated view to update this message with. If ``None`` is passed then the view is removed. + allowed_mentions: Optional[:class:`~discord.AllowedMentions`] + Controls the mentions being processed in this message. See :meth:`.Message.edit` + for more information. Raises ------- @@ -589,42 +596,25 @@ class InteractionResponse: if parent.type is not InteractionType.component: return - payload = {} - if content is not MISSING: - if content is None: - payload['content'] = None - else: - payload['content'] = str(content) - - if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix both embed and embeds keyword arguments') - - if embed is not MISSING: - if embed is None: - embeds = [] - else: - embeds = [embed] - - if embeds is not MISSING: - payload['embeds'] = [e.to_dict() for e in embeds] - - if attachments is not MISSING: - payload['attachments'] = [a.to_dict() for a in attachments] - - if view is not MISSING: + if view is not MISSING and message_id is not None: state.prevent_view_updates_for(message_id) - if view is None: - payload['components'] = [] - else: - payload['components'] = view.to_components() adapter = async_context.get() + params = interaction_message_response_params( + type=InteractionResponseType.message_update.value, + content=content, + embed=embed, + embeds=embeds, + attachments=attachments, + previous_allowed_mentions=parent._state.allowed_mentions, + allowed_mentions=allowed_mentions, + ) + await adapter.create_interaction_response( parent.id, parent.token, session=parent._session, - type=InteractionResponseType.message_update.value, - data=payload, + params=params, ) if view and not view.is_finished(): diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 77bd82e17..0f23527b7 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType from ..user import BaseUser, User from ..flags import MessageFlags from ..asset import Asset -from ..http import Route, handle_message_parameters +from ..http import Route, handle_message_parameters, MultipartParameters from ..mixins import Hashable from ..channel import PartialMessageable @@ -60,6 +60,7 @@ if TYPE_CHECKING: from ..file import File from ..embeds import Embed from ..mentions import AllowedMentions + from ..message import Attachment from ..state import ConnectionState from ..http import Response from ..types.webhook import ( @@ -349,16 +350,8 @@ class AsyncWebhookAdapter: token: str, *, session: aiohttp.ClientSession, - type: int, - data: Optional[Dict[str, Any]] = None, + params: MultipartParameters, ) -> Response[None]: - payload: Dict[str, Any] = { - 'type': type, - } - - if data is not None: - payload['data'] = data - route = Route( 'POST', '/interactions/{webhook_id}/{webhook_token}/callback', @@ -366,7 +359,10 @@ class AsyncWebhookAdapter: webhook_token=token, ) - return self.request(route, session=session, payload=payload) + if params.files: + return self.request(route, session=session, files=params.files, multipart=params.multipart) + else: + return self.request(route, session=session, payload=params.payload) def get_original_interaction_response( self, @@ -417,6 +413,118 @@ class AsyncWebhookAdapter: return self.request(r, session=session) +def interaction_response_params(type: int, data: Optional[Dict[str, Any]] = None) -> MultipartParameters: + payload: Dict[str, Any] = { + 'type': type, + } + if data is not None: + payload['data'] = data + + return MultipartParameters(payload=payload, multipart=None, files=None) + + +# This is a subset of handle_message_parameters +def interaction_message_response_params( + *, + type: int, + content: Optional[str] = MISSING, + tts: bool = False, + flags: MessageFlags = MISSING, + file: File = MISSING, + files: List[File] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: List[Embed] = MISSING, + attachments: List[Attachment] = MISSING, + view: Optional[View] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING, + previous_allowed_mentions: Optional[AllowedMentions] = None, +) -> MultipartParameters: + if files is not MISSING and file is not MISSING: + raise TypeError('Cannot mix file and files keyword arguments.') + if embeds is not MISSING and embed is not MISSING: + raise TypeError('Cannot mix embed and embeds keyword arguments.') + + data: Optional[Dict[str, Any]] = { + 'tts': tts, + } + + if embeds is not MISSING: + if len(embeds) > 10: + raise InvalidArgument('embeds has a maximum of 10 elements.') + data['embeds'] = [e.to_dict() for e in embeds] + + if embed is not MISSING: + if embed is None: + data['embeds'] = [] + else: + data['embeds'] = [embed.to_dict()] + + if content is not MISSING: + if content is not None: + data['content'] = str(content) + else: + data['content'] = None + + if view is not MISSING: + if view is not None: + data['components'] = view.to_components() + else: + data['components'] = [] + + if attachments is not MISSING: + # Note: This will be overwritten if file or files is provided + # However, right now this is only passed via edit not send + data['attachments'] = [a.to_dict() for a in attachments] + + if flags is not MISSING: + data['flags'] = flags.value + + if allowed_mentions: + if previous_allowed_mentions is not None: + data['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict() + else: + data['allowed_mentions'] = allowed_mentions.to_dict() + elif previous_allowed_mentions is not None: + data['allowed_mentions'] = previous_allowed_mentions.to_dict() + + multipart = [] + if file is not MISSING: + files = [file] + + if files: + for index, file in enumerate(files): + attachments_payload = [] + for index, file in enumerate(files): + attachment = { + 'id': index, + 'filename': file.filename, + } + + if file.description is not None: + attachment['description'] = file.description + + attachments_payload.append(attachment) + + data['attachments'] = attachments_payload + + data = {'type': type, 'data': data} + multipart.append({'name': 'payload_json', 'value': utils._to_json(data)}) + data = None + for index, file in enumerate(files): + multipart.append( + { + 'name': f'files[{index}]', + 'value': file.fp, + 'filename': file.filename, + 'content_type': 'application/octet-stream', + } + ) + else: + data = {'type': type, 'data': data} + + return MultipartParameters(payload=data, multipart=multipart, files=files) + + async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())