Browse Source

Return a named tuple with message from ForumChannel.create_thread

pull/7983/head
Rapptz 3 years ago
parent
commit
7f210c90f4
  1. 23
      discord/channel.py
  2. 2
      discord/http.py
  3. 5
      discord/types/threads.py

23
discord/channel.py

@ -33,6 +33,7 @@ from typing import (
List, List,
Literal, Literal,
Mapping, Mapping,
NamedTuple,
Optional, Optional,
TYPE_CHECKING, TYPE_CHECKING,
Sequence, Sequence,
@ -96,6 +97,11 @@ if TYPE_CHECKING:
from .types.snowflake import SnowflakeList from .types.snowflake import SnowflakeList
class ThreadWithMessage(NamedTuple):
thread: Thread
message: Message
class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""Represents a Discord guild text channel. """Represents a Discord guild text channel.
@ -2159,7 +2165,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
view: View = MISSING, view: View = MISSING,
suppress_embeds: bool = False, suppress_embeds: bool = False,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Thread: ) -> ThreadWithMessage:
"""|coro| """|coro|
Creates a thread in this forum. Creates a thread in this forum.
@ -2222,8 +2228,9 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
Returns Returns
-------- --------
:class:`Thread` Tuple[:class:`Thread`, :class:`Message`]
The created thread The created thread with the created message.
This is also accessible as a namedtuple with ``thread`` and ``message`` fields.
""" """
state = self._state state = self._state
@ -2267,8 +2274,16 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
flags=flags, flags=flags,
channel_payload=channel_payload, channel_payload=channel_payload,
) as params: ) as params:
# Circular import
from .message import Message
data = await state.http.start_thread_in_forum(self.id, params=params, reason=reason) data = await state.http.start_thread_in_forum(self.id, params=params, reason=reason)
return Thread(guild=self.guild, state=self._state, data=data) thread = Thread(guild=self.guild, state=self._state, data=data)
message = Message(state=self._state, channel=thread, data=data['message'])
if view:
self._state.store_view(view, message.id)
return ThreadWithMessage(thread=thread, message=message)
class DMChannel(discord.abc.Messageable, Hashable): class DMChannel(discord.abc.Messageable, Hashable):

2
discord/http.py

@ -990,7 +990,7 @@ class HTTPClient:
*, *,
params: MultipartParameters, params: MultipartParameters,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Response[threads.Thread]: ) -> Response[threads.ForumThread]:
query = {'use_nested_fields': 1} query = {'use_nested_fields': 1}
r = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id) r = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id)
if params.files: if params.files:

5
discord/types/threads.py

@ -28,6 +28,7 @@ from typing import List, Literal, Optional, TypedDict
from typing_extensions import NotRequired from typing_extensions import NotRequired
from .snowflake import Snowflake from .snowflake import Snowflake
from .message import Message
ThreadType = Literal[10, 11, 12] ThreadType = Literal[10, 11, 12]
ThreadArchiveDuration = Literal[60, 1440, 4320, 10080] ThreadArchiveDuration = Literal[60, 1440, 4320, 10080]
@ -72,3 +73,7 @@ class ThreadPaginationPayload(TypedDict):
threads: List[Thread] threads: List[Thread]
members: List[ThreadMember] members: List[ThreadMember]
has_more: bool has_more: bool
class ForumThread(Thread):
message: Message

Loading…
Cancel
Save