From 29344b9cce6eb85676895b77ad6a1e658b7aa56c Mon Sep 17 00:00:00 2001 From: Lilly Rose Berner Date: Sun, 10 Dec 2023 17:37:03 +0100 Subject: [PATCH] Add thread getters to Message --- discord/message.py | 79 ++++++++++++++++++++++++++++++++++++++++ discord/types/message.py | 2 + 2 files changed, 81 insertions(+) diff --git a/discord/message.py b/discord/message.py index c40fed6a6..0b29e469c 100644 --- a/discord/message.py +++ b/discord/message.py @@ -804,6 +804,20 @@ class PartialMessage(Hashable): guild_id = getattr(self.guild, 'id', '@me') return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}' + @property + def thread(self) -> Optional[Thread]: + """Optional[:class:`Thread`]: The public thread created from this message, if it exists. + + .. note:: + + This does not retrieve archived threads, as they are not retained in the internal + cache. Use :meth:`fetch_thread` instead. + + .. versionadded:: 2.4 + """ + if self.guild is not None: + return self.guild.get_thread(self.id) + async def fetch(self) -> Message: """|coro| @@ -1280,6 +1294,40 @@ class PartialMessage(Hashable): ) return Thread(guild=self.guild, state=self._state, data=data) + async def fetch_thread(self) -> Thread: + """|coro| + + Retrieves the public thread attached to this message. + + .. note:: + + This method is an API call. For general usage, consider :attr:`thread` instead. + + .. versionadded:: 2.4 + + Raises + ------- + InvalidData + An unknown channel type was received from Discord + or the guild the thread belongs to is not the same + as the one in this object points to. + HTTPException + Retrieving the thread failed. + NotFound + There is no thread attached to this message. + Forbidden + You do not have permission to fetch this channel. + + Returns + -------- + :class:`.Thread` + The public thread attached to this message. + """ + if self.guild is None: + raise ValueError('This message does not have guild info attached.') + + return await self.guild.fetch_channel(self.id) # type: ignore # Can only be Thread in this case + @overload async def reply( self, @@ -1572,6 +1620,7 @@ class Message(PartialMessage, Hashable): '_cs_raw_channel_mentions', '_cs_raw_role_mentions', '_cs_system_content', + '_thread', 'tts', 'content', 'webhook_id', @@ -1640,6 +1689,21 @@ class Message(PartialMessage, Hashable): except AttributeError: self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id')) + self._thread: Optional[Thread] = None + + if self.guild is not None: + try: + thread = data['thread'] + except KeyError: + pass + else: + self._thread = self.guild.get_thread(int(thread['id'])) + + if self._thread is not None: + self._thread._update(thread) + else: + self._thread = Thread(guild=self.guild, state=state, data=thread) + self.interaction: Optional[MessageInteraction] = None try: @@ -1982,6 +2046,21 @@ class Message(PartialMessage, Hashable): """Optional[:class:`datetime.datetime`]: An aware UTC datetime object containing the edited time of the message.""" return self._edited_timestamp + @property + def thread(self) -> Optional[Thread]: + """Optional[:class:`Thread`]: The public thread created from this message, if it exists. + + .. note:: + + For messages received via the gateway this does not retrieve archived threads, as they + are not retained in the internal cache. Use :meth:`fetch_thread` instead. + + .. versionadded:: 2.4 + """ + if self.guild is not None: + # Fall back to guild threads in case one was created after the message + return self._thread or self.guild.get_thread(self.id) + def is_system(self) -> bool: """:class:`bool`: Whether the message is a system message. diff --git a/discord/types/message.py b/discord/types/message.py index e1046c82a..883f84211 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -36,6 +36,7 @@ from .channel import ChannelType from .components import Component from .interactions import MessageInteraction from .sticker import StickerItem +from .threads import Thread class PartialMessage(TypedDict): @@ -146,6 +147,7 @@ class Message(PartialMessage): components: NotRequired[List[Component]] position: NotRequired[int] role_subscription_data: NotRequired[RoleSubscriptionData] + thread: NotRequired[Thread] AllowedMentionType = Literal['roles', 'users', 'everyone']