From 9c066a8cf699f997f4531e81abaf924feb49c78b Mon Sep 17 00:00:00 2001 From: Rapptz Date: Fri, 18 Feb 2022 06:59:34 -0500 Subject: [PATCH] Refactor internal message sending and editing parameter passing This reduces some repetition in many functions and is ripped out of the webhook code. This also removes the unused HTTP functions for interaction responses since those belong in the webhook code rather than the HTTPClient. --- discord/abc.py | 114 ++-------- discord/http.py | 463 ++++++++++++++------------------------ discord/interactions.py | 3 +- discord/message.py | 71 +++--- discord/webhook/async_.py | 110 +-------- discord/webhook/sync.py | 4 +- 6 files changed, 233 insertions(+), 532 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index fd2dc4bb9..3b7027394 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -51,6 +51,7 @@ from .permissions import PermissionOverwrite, Permissions from .role import Role from .invite import Invite from .file import File +from .http import handle_message_parameters from .voice_client import VoiceClient, VoiceProtocol from .sticker import GuildSticker, StickerItem from . import utils @@ -1330,107 +1331,40 @@ class Messageable: channel = await self._get_channel() state = self._state content = str(content) if content is not None else None - - if embed is not None and embeds is not None: - raise InvalidArgument('cannot pass both embed and embeds parameter to send()') - - if embed is not None: - embed = embed.to_dict() - - elif embeds is not None: - if len(embeds) > 10: - raise InvalidArgument('embeds parameter must be a list of up to 10 elements') - embeds = [embed.to_dict() for embed in embeds] + previous_allowed_mention = state.allowed_mentions if stickers is not None: stickers = [sticker.id for sticker in stickers] - - if allowed_mentions is not None: - if state.allowed_mentions is not None: - allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() - else: - allowed_mentions = allowed_mentions.to_dict() else: - allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() - - if mention_author is not None: - allowed_mentions = allowed_mentions or AllowedMentions().to_dict() - allowed_mentions['replied_user'] = bool(mention_author) + stickers = MISSING if reference is not None: try: reference = reference.to_message_reference_dict() except AttributeError: raise InvalidArgument('reference parameter must be Message, MessageReference, or PartialMessage') from None - - if view: - if not hasattr(view, '__discord_ui_view__'): - raise InvalidArgument(f'view parameter must be View not {view.__class__!r}') - - components = view.to_components() - else: - components = None - - if file is not None and files is not None: - raise InvalidArgument('cannot pass both file and files parameter to send()') - - if file is not None: - if not isinstance(file, File): - raise InvalidArgument('file parameter must be File') - - try: - data = await state.http.send_files( - channel.id, - files=[file], - allowed_mentions=allowed_mentions, - content=content, - tts=tts, - embed=embed, - embeds=embeds, - nonce=nonce, - message_reference=reference, - stickers=stickers, - components=components, - ) - finally: - file.close() - - elif files is not None: - if len(files) > 10: - raise InvalidArgument('files parameter must be a list of up to 10 elements') - elif not all(isinstance(file, File) for file in files): - raise InvalidArgument('files parameter must be a list of File') - - try: - data = await state.http.send_files( - channel.id, - files=files, - content=content, - tts=tts, - embed=embed, - embeds=embeds, - nonce=nonce, - allowed_mentions=allowed_mentions, - message_reference=reference, - stickers=stickers, - components=components, - ) - finally: - for f in files: - f.close() else: - data = await state.http.send_message( - channel.id, - content, - tts=tts, - embed=embed, - embeds=embeds, - nonce=nonce, - allowed_mentions=allowed_mentions, - message_reference=reference, - stickers=stickers, - components=components, - ) + reference = MISSING + + if view and not hasattr(view, '__discord_ui_view__'): + raise InvalidArgument(f'view parameter must be View not {view.__class__!r}') + + with handle_message_parameters( + content=content, + tts=tts, + file=file if file is not None else MISSING, + files=files if files is not None else MISSING, + embed=embed if embed is not None else MISSING, + embeds=embeds if embeds is not None else MISSING, + nonce=nonce, + allowed_mentions=allowed_mentions, + message_reference=reference, + previous_allowed_mentions=previous_allowed_mention, + mention_author=mention_author, + stickers=stickers, + view=view, + ) as params: + data = await state.http.send_message(channel.id, params=params) ret = state.create_message(channel=channel, data=data) if view: diff --git a/discord/http.py b/discord/http.py index 04c90c89f..f38ca593a 100644 --- a/discord/http.py +++ b/discord/http.py @@ -35,14 +35,15 @@ from typing import ( Iterable, List, Literal, + NamedTuple, Optional, + overload, Sequence, - TYPE_CHECKING, Tuple, + TYPE_CHECKING, Type, TypeVar, Union, - overload, ) from urllib.parse import quote as _uriquote import weakref @@ -58,6 +59,11 @@ _log = logging.getLogger(__name__) if TYPE_CHECKING: from .file import File + from .ui.view import View + from .embeds import Embed + from .mentions import AllowedMentions + from .message import Attachment + from .flags import MessageFlags from .enums import ( AuditLogAction, InteractionResponseType, @@ -110,6 +116,149 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any] return text +class MultipartParameters(NamedTuple): + payload: Optional[Dict[str, Any]] + multipart: Optional[List[Dict[str, Any]]] + files: Optional[List[File]] + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: Optional[Type[BE]], + exc: Optional[BE], + traceback: Optional[TracebackType], + ) -> None: + if self.files: + for file in self.files: + file.close() + + +def handle_message_parameters( + content: Optional[str] = MISSING, + *, + username: str = MISSING, + avatar_url: Any = MISSING, + tts: bool = False, + nonce: Optional[Union[int, str]] = None, + flags: MessageFlags = MISSING, + file: File = MISSING, + files: List[File] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: List[Embed] = MISSING, + attachments: List[Attachment] = MISSING, + view: Optional[View] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING, + message_reference: Optional[message.MessageReference] = MISSING, + stickers: Optional[SnowflakeList] = MISSING, + previous_allowed_mentions: Optional[AllowedMentions] = None, + mention_author: Optional[bool] = None, +) -> MultipartParameters: + if files is not MISSING and file is not MISSING: + raise TypeError('Cannot mix file and files keyword arguments.') + if embeds is not MISSING and embed is not MISSING: + raise TypeError('Cannot mix embed and embeds keyword arguments.') + + payload = {} + if embeds is not MISSING: + if len(embeds) > 10: + raise InvalidArgument('embeds has a maximum of 10 elements.') + payload['embeds'] = [e.to_dict() for e in embeds] + + if embed is not MISSING: + if embed is None: + payload['embeds'] = [] + else: + payload['embeds'] = [embed.to_dict()] + + if content is not MISSING: + if content is not None: + payload['content'] = str(content) + else: + payload['content'] = None + + if view is not MISSING: + if view is not None: + payload['components'] = view.to_components() + else: + payload['components'] = [] + + if nonce is not MISSING: + payload['nonce'] = str(nonce) + + if message_reference is not MISSING: + payload['message_reference'] = message_reference + + if attachments is not MISSING: + # Note: This will be overwritten if file or files is provided + # However, right now this is only passed via Message.edit not Messageable.send + payload['attachments'] = [a.to_dict() for a in attachments] + + if stickers is not MISSING: + if stickers is not None: + payload['sticker_ids'] = stickers + else: + payload['sticker_ids'] = [] + + payload['tts'] = tts + if avatar_url: + payload['avatar_url'] = str(avatar_url) + if username: + payload['username'] = username + + if flags is not MISSING: + payload['flags'] = flags.value + + if allowed_mentions: + if previous_allowed_mentions is not None: + payload['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict() + else: + payload['allowed_mentions'] = allowed_mentions.to_dict() + elif previous_allowed_mentions is not None: + payload['allowed_mentions'] = previous_allowed_mentions.to_dict() + + if mention_author is not None: + try: + payload['allowed_mentions']['replied_user'] = mention_author + except KeyError: + pass + + multipart = [] + if file is not MISSING: + files = [file] + + if files: + for index, file in enumerate(files): + attachments_payload = [] + for index, file in enumerate(files): + attachment = { + 'id': index, + 'filename': file.filename, + } + + if file.description is not None: + attachment['description'] = file.description + + attachments_payload.append(attachment) + + payload['attachments'] = attachments_payload + + multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)}) + payload = None + for index, file in enumerate(files): + multipart.append( + { + 'name': f'files[{index}]', + 'value': file.fp, + 'filename': file.filename, + 'content_type': 'application/octet-stream', + } + ) + + return MultipartParameters(payload=payload, multipart=multipart, files=files) + + class Route: BASE: ClassVar[str] = 'https://discord.com/api/v8' @@ -268,7 +417,7 @@ class HTTPClient: if form: # with quote_fields=True '[' and ']' in file field names are escaped, which discord does not support - form_data = aiohttp.FormData(quote_fields=False) + form_data = aiohttp.FormData(quote_fields=False) for params in form: form_data.add_field(**params) kwargs['data'] = form_data @@ -417,144 +566,18 @@ class HTTPClient: def send_message( self, channel_id: Snowflake, - content: Optional[str], *, - tts: bool = False, - embed: Optional[embed.Embed] = None, - embeds: Optional[List[embed.Embed]] = None, - nonce: Optional[str] = None, - allowed_mentions: Optional[message.AllowedMentions] = None, - message_reference: Optional[message.MessageReference] = None, - stickers: Optional[List[sticker.StickerItem]] = None, - components: Optional[List[components.Component]] = None, + params: MultipartParameters, ) -> Response[message.Message]: r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) - payload = {} - - if content: - payload['content'] = content - - if tts: - payload['tts'] = True - - if embed: - payload['embeds'] = [embed] - - if embeds: - payload['embeds'] = embeds - - if nonce: - payload['nonce'] = nonce - - if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions - - if message_reference: - payload['message_reference'] = message_reference - - if components: - payload['components'] = components - - if stickers: - payload['sticker_ids'] = stickers - - return self.request(r, json=payload) + if params.files: + return self.request(r, files=params.files, form=params.multipart) + else: + return self.request(r, json=params.payload) def send_typing(self, channel_id: Snowflake) -> Response[None]: return self.request(Route('POST', '/channels/{channel_id}/typing', channel_id=channel_id)) - def send_multipart_helper( - self, - route: Route, - *, - files: Sequence[File], - content: Optional[str] = None, - tts: bool = False, - embed: Optional[embed.Embed] = None, - embeds: Optional[Iterable[Optional[embed.Embed]]] = None, - nonce: Optional[str] = None, - allowed_mentions: Optional[message.AllowedMentions] = None, - message_reference: Optional[message.MessageReference] = None, - stickers: Optional[List[sticker.StickerItem]] = None, - components: Optional[List[components.Component]] = None, - ) -> Response[message.Message]: - form = [] - - payload: Dict[str, Any] = {'tts': tts} - if content: - payload['content'] = content - if embed: - payload['embeds'] = [embed] - if embeds: - payload['embeds'] = embeds - if nonce: - payload['nonce'] = nonce - if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions - if message_reference: - payload['message_reference'] = message_reference - if components: - payload['components'] = components - if stickers: - payload['sticker_ids'] = stickers - if files: - attachments = [] - for index, file in enumerate(files): - attachment = { - "id": index, - "filename": file.filename, - } - - if file.description is not None: - attachment["description"] = file.description - - attachments.append(attachment) - - payload['attachments'] = attachments - - form.append({'name': 'payload_json', 'value': utils._to_json(payload)}) - for index, file in enumerate(files): - form.append( - { - 'name': f'files[{index}]', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'image/png', - } - ) - - return self.request(route, form=form, files=files) - - def send_files( - self, - channel_id: Snowflake, - *, - files: Sequence[File], - content: Optional[str] = None, - tts: bool = False, - embed: Optional[embed.Embed] = None, - embeds: Optional[List[embed.Embed]] = None, - nonce: Optional[str] = None, - allowed_mentions: Optional[message.AllowedMentions] = None, - message_reference: Optional[message.MessageReference] = None, - stickers: Optional[List[sticker.StickerItem]] = None, - components: Optional[List[components.Component]] = None, - ) -> Response[message.Message]: - r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) - return self.send_multipart_helper( - r, - files=files, - content=content, - tts=tts, - embed=embed, - embeds=embeds, - nonce=nonce, - allowed_mentions=allowed_mentions, - message_reference=message_reference, - stickers=stickers, - components=components, - ) - def delete_message( self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None ) -> Response[None]: @@ -571,9 +594,9 @@ class HTTPClient: return self.request(r, json=payload, reason=reason) - def edit_message(self, channel_id: Snowflake, message_id: Snowflake, **fields: Any) -> Response[message.Message]: + def edit_message(self, channel_id: Snowflake, message_id: Snowflake, *, params: MultipartParameters) -> Response[message.Message]: r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) - return self.request(r, json=fields) + return self.request(r, json=params.payload) def add_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: r = Route( @@ -1241,7 +1264,11 @@ class HTTPClient: ) def modify_guild_sticker( - self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: Optional[str], + self, + guild_id: Snowflake, + sticker_id: Snowflake, + payload: sticker.EditGuildSticker, + reason: Optional[str], ) -> Response[sticker.GuildSticker]: return self.request( Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id), @@ -1706,9 +1733,7 @@ class HTTPClient: def get_global_commands(self, application_id: Snowflake) -> Response[List[command.ApplicationCommand]]: return self.request(Route('GET', '/applications/{application_id}/commands', application_id=application_id)) - def get_global_command( - self, application_id: Snowflake, command_id: Snowflake - ) -> Response[command.ApplicationCommand]: + def get_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[command.ApplicationCommand]: r = Route( 'GET', '/applications/{application_id}/commands/{command_id}', @@ -1750,9 +1775,7 @@ class HTTPClient: ) return self.request(r) - def bulk_upsert_global_commands( - self, application_id: Snowflake, payload - ) -> Response[List[command.ApplicationCommand]]: + def bulk_upsert_global_commands(self, application_id: Snowflake, payload) -> Response[List[command.ApplicationCommand]]: r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id) return self.request(r, json=payload) @@ -1849,160 +1872,6 @@ class HTTPClient: ) return self.request(r, json=payload) - # Interaction responses - - def _edit_webhook_helper( - self, - route: Route, - file: Optional[File] = None, - content: Optional[str] = None, - embeds: Optional[List[embed.Embed]] = None, - allowed_mentions: Optional[message.AllowedMentions] = None, - ): - - payload: Dict[str, Any] = {} - if content: - payload['content'] = content - if embeds: - payload['embeds'] = embeds - if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions - - form: List[Dict[str, Any]] = [ - { - 'name': 'payload_json', - 'value': utils._to_json(payload), - } - ] - - if file: - form.append( - { - 'name': 'file', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', - } - ) - - return self.request(route, form=form, files=[file] if file else None) - - def create_interaction_response( - self, - interaction_id: Snowflake, - token: str, - *, - type: InteractionResponseType, - data: Optional[Dict[str, Any]] = None, - ) -> Response[None]: - r = Route( - 'POST', - '/interactions/{interaction_id}/{interaction_token}/callback', - interaction_id=interaction_id, - interaction_token=token, - ) - payload: Dict[str, Any] = { - 'type': type, - } - - if data is not None: - payload['data'] = data - - return self.request(r, json=payload) - - def get_original_interaction_response( - self, - application_id: Snowflake, - token: str, - ) -> Response[message.Message]: - r = Route( - 'GET', - '/webhooks/{application_id}/{interaction_token}/messages/@original', - application_id=application_id, - interaction_token=token, - ) - return self.request(r) - - def edit_original_interaction_response( - self, - application_id: Snowflake, - token: str, - file: Optional[File] = None, - content: Optional[str] = None, - embeds: Optional[List[embed.Embed]] = None, - allowed_mentions: Optional[message.AllowedMentions] = None, - ) -> Response[message.Message]: - r = Route( - 'PATCH', - '/webhooks/{application_id}/{interaction_token}/messages/@original', - application_id=application_id, - interaction_token=token, - ) - return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions) - - def delete_original_interaction_response(self, application_id: Snowflake, token: str) -> Response[None]: - r = Route( - 'DELETE', - '/webhooks/{application_id}/{interaction_token}/messages/@original', - application_id=application_id, - interaction_token=token, - ) - return self.request(r) - - def create_followup_message( - self, - application_id: Snowflake, - token: str, - files: List[File] = [], - content: Optional[str] = None, - tts: bool = False, - embeds: Optional[List[embed.Embed]] = None, - allowed_mentions: Optional[message.AllowedMentions] = None, - ) -> Response[message.Message]: - r = Route( - 'POST', - '/webhooks/{application_id}/{interaction_token}', - application_id=application_id, - interaction_token=token, - ) - return self.send_multipart_helper( - r, - content=content, - files=files, - tts=tts, - embeds=embeds, - allowed_mentions=allowed_mentions, - ) - - def edit_followup_message( - self, - application_id: Snowflake, - token: str, - message_id: Snowflake, - file: Optional[File] = None, - content: Optional[str] = None, - embeds: Optional[List[embed.Embed]] = None, - allowed_mentions: Optional[message.AllowedMentions] = None, - ) -> Response[message.Message]: - r = Route( - 'PATCH', - '/webhooks/{application_id}/{interaction_token}/messages/{message_id}', - application_id=application_id, - interaction_token=token, - message_id=message_id, - ) - return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions) - - def delete_followup_message(self, application_id: Snowflake, token: str, message_id: Snowflake) -> Response[None]: - r = Route( - 'DELETE', - '/webhooks/{application_id}/{interaction_token}/messages/{message_id}', - application_id=application_id, - interaction_token=token, - message_id=message_id, - ) - return self.request(r) - def get_guild_application_command_permissions( self, application_id: Snowflake, diff --git a/discord/interactions.py b/discord/interactions.py index b89d49f53..35d6c5982 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -38,7 +38,8 @@ from .member import Member from .message import Message, Attachment from .object import Object from .permissions import Permissions -from .webhook.async_ import async_context, Webhook, handle_message_parameters +from .http import handle_message_parameters +from .webhook.async_ import async_context, Webhook __all__ = ( 'Interaction', diff --git a/discord/message.py b/discord/message.py index c65d714b3..f56c5c45b 100644 --- a/discord/message.py +++ b/discord/message.py @@ -29,7 +29,21 @@ import datetime import re import io from os import PathLike -from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type +from typing import ( + Dict, + TYPE_CHECKING, + Union, + List, + Optional, + Any, + Callable, + Tuple, + ClassVar, + Optional, + overload, + TypeVar, + Type, +) from . import utils from .reaction import Reaction @@ -43,6 +57,7 @@ from .member import Member from .flags import MessageFlags from .file import File from .utils import escape_mentions, MISSING +from .http import handle_message_parameters from .guild import Guild from .mixins import Hashable from .sticker import StickerItem @@ -350,7 +365,7 @@ class DeletedReferencedMessage: def id(self) -> int: """:class:`int`: The message ID of the deleted referenced message.""" # the parent's message id won't be None here - return self._parent.message_id # type: ignore + return self._parent.message_id # type: ignore @property def channel_id(self) -> int: @@ -1217,6 +1232,8 @@ class Message(Hashable): attachments: List[:class:`Attachment`] A list of attachments to keep in the message. If ``[]`` is passed then all attachments are removed. + + .. versionadded:: 2.0 suppress: :class:`bool` Whether to suppress embeds for the message. This removes all the embeds if set to ``True``. If set to ``False`` @@ -1250,50 +1267,26 @@ class Message(Hashable): You specified both ``embed`` and ``embeds`` """ - payload: Dict[str, Any] = {} - if content is not MISSING: - if content is not None: - payload['content'] = str(content) - else: - payload['content'] = None - - if embed is not MISSING and embeds is not MISSING: - raise InvalidArgument('cannot pass both embed and embeds parameter to edit()') - - if embed is not MISSING: - if embed is None: - payload['embeds'] = [] - else: - payload['embeds'] = [embed.to_dict()] - elif embeds is not MISSING: - payload['embeds'] = [e.to_dict() for e in embeds] - + previous_allowed_mentions = self._state.allowed_mentions if suppress is not MISSING: flags = MessageFlags._from_value(self.flags.value) - flags.suppress_embeds = suppress - payload['flags'] = flags.value - - if allowed_mentions is MISSING: - if self._state.allowed_mentions is not None and self.author.id == self._state.self_id: - payload['allowed_mentions'] = self._state.allowed_mentions.to_dict() else: - if allowed_mentions is not None: - if self._state.allowed_mentions is not None: - payload['allowed_mentions'] = self._state.allowed_mentions.merge(allowed_mentions).to_dict() - else: - payload['allowed_mentions'] = allowed_mentions.to_dict() - - if attachments is not MISSING: - payload['attachments'] = [a.to_dict() for a in attachments] + flags = MISSING if view is not MISSING: self._state.prevent_view_updates_for(self.id) - if view: - payload['components'] = view.to_components() - else: - payload['components'] = [] - data = await self._state.http.edit_message(self.channel.id, self.id, **payload) + params = handle_message_parameters( + content=content, + flags=flags, + embed=embed, + embeds=embeds, + attachments=attachments, + view=view, + allowed_mentions=allowed_mentions, + previous_allowed_mentions=previous_allowed_mentions, + ) + data = await self._state.http.edit_message(self.channel.id, self.id, params=params) message = Message(state=self._state, channel=self.channel, data=data) if view and not view.is_finished(): diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index be58a93ac..77bd82e17 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -41,8 +41,9 @@ from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, Discor from ..message import Message from ..enums import try_enum, WebhookType from ..user import BaseUser, User +from ..flags import MessageFlags from ..asset import Asset -from ..http import Route +from ..http import Route, handle_message_parameters from ..mixins import Hashable from ..channel import PartialMessageable @@ -416,107 +417,6 @@ class AsyncWebhookAdapter: return self.request(r, session=session) -class ExecuteWebhookParameters(NamedTuple): - payload: Optional[Dict[str, Any]] - multipart: Optional[List[Dict[str, Any]]] - files: Optional[List[File]] - - -def handle_message_parameters( - content: Optional[str] = MISSING, - *, - username: str = MISSING, - avatar_url: Any = MISSING, - tts: bool = False, - ephemeral: bool = False, - file: File = MISSING, - files: List[File] = MISSING, - embed: Optional[Embed] = MISSING, - embeds: List[Embed] = MISSING, - view: Optional[View] = MISSING, - allowed_mentions: Optional[AllowedMentions] = MISSING, - previous_allowed_mentions: Optional[AllowedMentions] = None, -) -> ExecuteWebhookParameters: - if files is not MISSING and file is not MISSING: - raise TypeError('Cannot mix file and files keyword arguments.') - if embeds is not MISSING and embed is not MISSING: - raise TypeError('Cannot mix embed and embeds keyword arguments.') - - payload = {} - if embeds is not MISSING: - if len(embeds) > 10: - raise InvalidArgument('embeds has a maximum of 10 elements.') - payload['embeds'] = [e.to_dict() for e in embeds] - - if embed is not MISSING: - if embed is None: - payload['embeds'] = [] - else: - payload['embeds'] = [embed.to_dict()] - - if content is not MISSING: - if content is not None: - payload['content'] = str(content) - else: - payload['content'] = None - - if view is not MISSING: - if view is not None: - payload['components'] = view.to_components() - else: - payload['components'] = [] - - payload['tts'] = tts - if avatar_url: - payload['avatar_url'] = str(avatar_url) - if username: - payload['username'] = username - if ephemeral: - payload['flags'] = 64 - - if allowed_mentions: - if previous_allowed_mentions is not None: - payload['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict() - else: - payload['allowed_mentions'] = allowed_mentions.to_dict() - elif previous_allowed_mentions is not None: - payload['allowed_mentions'] = previous_allowed_mentions.to_dict() - - multipart = [] - if file is not MISSING: - files = [file] - - if files: - for index, file in enumerate(files): - attachments = [] - for index, file in enumerate(files): - attachment = { - "id": index, - "filename": file.filename, - } - - if file.description is not None: - attachment["description"] = file.description - - attachments.append(attachment) - - payload['attachments'] = attachments - - multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)}) - payload = None - for index, file in enumerate(files): - multipart.append( - { - 'name': f'files[{index}]', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', - } - ) - - return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files) - - async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter()) @@ -1356,6 +1256,10 @@ class Webhook(BaseWebhook): previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None) if content is None: content = MISSING + if ephemeral: + flags = MessageFlags._from_value(64) + else: + flags = MISSING application_webhook = self.type is WebhookType.application if ephemeral and not application_webhook: @@ -1379,7 +1283,7 @@ class Webhook(BaseWebhook): files=files, embed=embed, embeds=embeds, - ephemeral=ephemeral, + flags=flags, view=view, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index 3f7cd6018..bf79c2c1b 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -43,10 +43,10 @@ import weakref from .. import utils from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError from ..message import Message -from ..http import Route +from ..http import Route, handle_message_parameters from ..channel import PartialMessageable -from .async_ import BaseWebhook, handle_message_parameters, _WebhookState +from .async_ import BaseWebhook, _WebhookState __all__ = ( 'SyncWebhook',