From 7f210c90f49fe72b17d76e0be454a333cbc79f76 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 1 May 2022 19:32:11 -0400 Subject: [PATCH] Return a named tuple with message from ForumChannel.create_thread --- discord/channel.py | 23 +++++++++++++++++++---- discord/http.py | 2 +- discord/types/threads.py | 5 +++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/discord/channel.py b/discord/channel.py index 74ff2bd61..b6a828435 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -33,6 +33,7 @@ from typing import ( List, Literal, Mapping, + NamedTuple, Optional, TYPE_CHECKING, Sequence, @@ -96,6 +97,11 @@ if TYPE_CHECKING: from .types.snowflake import SnowflakeList +class ThreadWithMessage(NamedTuple): + thread: Thread + message: Message + + class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """Represents a Discord guild text channel. @@ -2159,7 +2165,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): view: View = MISSING, suppress_embeds: bool = False, reason: Optional[str] = None, - ) -> Thread: + ) -> ThreadWithMessage: """|coro| Creates a thread in this forum. @@ -2222,8 +2228,9 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): Returns -------- - :class:`Thread` - The created thread + Tuple[:class:`Thread`, :class:`Message`] + The created thread with the created message. + This is also accessible as a namedtuple with ``thread`` and ``message`` fields. """ state = self._state @@ -2267,8 +2274,16 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): flags=flags, channel_payload=channel_payload, ) as params: + # Circular import + from .message import Message + data = await state.http.start_thread_in_forum(self.id, params=params, reason=reason) - return Thread(guild=self.guild, state=self._state, data=data) + thread = Thread(guild=self.guild, state=self._state, data=data) + message = Message(state=self._state, channel=thread, data=data['message']) + if view: + self._state.store_view(view, message.id) + + return ThreadWithMessage(thread=thread, message=message) class DMChannel(discord.abc.Messageable, Hashable): diff --git a/discord/http.py b/discord/http.py index 226e2dda9..c06865f9f 100644 --- a/discord/http.py +++ b/discord/http.py @@ -990,7 +990,7 @@ class HTTPClient: *, params: MultipartParameters, reason: Optional[str] = None, - ) -> Response[threads.Thread]: + ) -> Response[threads.ForumThread]: query = {'use_nested_fields': 1} r = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id) if params.files: diff --git a/discord/types/threads.py b/discord/types/threads.py index a0471c0df..bd8e24334 100644 --- a/discord/types/threads.py +++ b/discord/types/threads.py @@ -28,6 +28,7 @@ from typing import List, Literal, Optional, TypedDict from typing_extensions import NotRequired from .snowflake import Snowflake +from .message import Message ThreadType = Literal[10, 11, 12] ThreadArchiveDuration = Literal[60, 1440, 4320, 10080] @@ -72,3 +73,7 @@ class ThreadPaginationPayload(TypedDict): threads: List[Thread] members: List[ThreadMember] has_more: bool + + +class ForumThread(Thread): + message: Message