From 638b56f474d75e6ae5a69f645aba8908c7ee9b4e Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 1 May 2022 19:32:11 -0400 Subject: [PATCH] Return a named tuple with message from ForumChannel.create_thread --- discord/channel.py | 23 +++++++++++++++++++---- discord/http.py | 2 +- discord/types/threads.py | 5 +++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/discord/channel.py b/discord/channel.py index 0ead33073..6514396f7 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -32,6 +32,7 @@ from typing import ( Iterable, List, Mapping, + NamedTuple, Optional, TYPE_CHECKING, Sequence, @@ -98,6 +99,11 @@ if TYPE_CHECKING: from .types.snowflake import SnowflakeList +class ThreadWithMessage(NamedTuple): + thread: Thread + message: Message + + class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """Represents a Discord guild text channel. @@ -2109,7 +2115,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): mention_author: bool = MISSING, suppress_embeds: bool = False, reason: Optional[str] = None, - ) -> Thread: + ) -> ThreadWithMessage: """|coro| Creates a thread in this forum. @@ -2166,8 +2172,9 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): Returns -------- - :class:`Thread` - The created thread + Tuple[:class:`Thread`, :class:`Message`] + The created thread with the created message. + This is also accessible as a namedtuple with ``thread`` and ``message`` fields. """ state = self._state @@ -2207,8 +2214,16 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): flags=flags, channel_payload=channel_payload, ) as params: + # Circular import + from .message import Message + data = await state.http.start_thread_in_forum(self.id, params=params, reason=reason) - return Thread(guild=self.guild, state=self._state, data=data) + thread = Thread(guild=self.guild, state=self._state, data=data) + message = Message(state=self._state, channel=thread, data=data['message']) + if view: + self._state.store_view(view, message.id) + + return ThreadWithMessage(thread=thread, message=message) class DMChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable): diff --git a/discord/http.py b/discord/http.py index d50add7ab..ccf43162a 100644 --- a/discord/http.py +++ b/discord/http.py @@ -1228,7 +1228,7 @@ class HTTPClient: *, params: MultipartParameters, reason: Optional[str] = None, - ) -> Response[threads.Thread]: + ) -> Response[threads.ForumThread]: query = {'use_nested_fields': 1} r = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id) if params.files: diff --git a/discord/types/threads.py b/discord/types/threads.py index 72fae1c3b..de5636c69 100644 --- a/discord/types/threads.py +++ b/discord/types/threads.py @@ -28,6 +28,7 @@ from typing import List, Literal, Optional, TypedDict from typing_extensions import NotRequired from .snowflake import Snowflake +from .message import Message ThreadType = Literal[10, 11, 12] ThreadArchiveDuration = Literal[60, 1440, 4320, 10080] @@ -73,3 +74,7 @@ class ThreadPaginationPayload(TypedDict): threads: List[Thread] members: List[ThreadMember] has_more: bool + + +class ForumThread(Thread): + message: Message