diff --git a/discord/channel.py b/discord/channel.py index 5a213fe1c..dbe8533fd 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -691,7 +691,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): type=ChannelType.public_thread.value, ) - return Thread(guild=self.guild, data=data) + return Thread(guild=self.guild, state=self._state, data=data) def archived_threads( self, @@ -753,7 +753,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ data = await self._state.http.get_active_threads(self.id) # TODO: thread members? - return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])] + return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])] class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): @@ -1924,3 +1924,9 @@ def _channel_factory(channel_type: Union[ChannelType, int]): return GroupChannel, value else: return cls, value + +def _threaded_channel_factory(channel_type: Union[ChannelType, int]): + cls, value = _channel_factory(channel_type) + if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread): + return Thread, value + return cls, value diff --git a/discord/client.py b/discord/client.py index 2b3c3e17c..24ceb31b7 100644 --- a/discord/client.py +++ b/discord/client.py @@ -39,7 +39,7 @@ from .template import Template from .widget import Widget from .guild import Guild from .emoji import Emoji -from .channel import _channel_factory +from .channel import _threaded_channel_factory from .enums import ChannelType from .mentions import AllowedMentions from .errors import * @@ -58,6 +58,7 @@ from .iterators import GuildIterator from .appinfo import AppInfo from .ui.view import View from .stage_instance import StageInstance +from .threads import Thread if TYPE_CHECKING: from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake @@ -1371,10 +1372,10 @@ class Client: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]: + async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]: """|coro| - Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID. + Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. .. note:: @@ -1395,12 +1396,12 @@ class Client: Returns -------- - Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`] + Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`] The channel from the ID. """ data = await self.http.get_channel(channel_id) - factory, ch_type = _channel_factory(data['type']) + factory, ch_type = _threaded_channel_factory(data['type']) if factory is None: raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) diff --git a/discord/guild.py b/discord/guild.py index 54b86b29b..1cf8eef04 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -287,7 +287,7 @@ class Guild(Hashable): self._members[member.id] = member def _store_thread(self, payload: ThreadPayload, /) -> Thread: - thread = Thread(guild=self, data=payload) + thread = Thread(guild=self, state=self._state, data=payload) self._threads[thread.id] = thread return thread @@ -466,7 +466,7 @@ class Guild(Hashable): if 'threads' in data: threads = data['threads'] for thread in threads: - self._add_thread(Thread(guild=self, data=thread)) + self._add_thread(Thread(guild=self, state=self._state, data=thread)) @property def channels(self) -> List[GuildChannel]: diff --git a/discord/iterators.py b/discord/iterators.py index 2f272b70e..f725d527e 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -750,4 +750,4 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): def create_thread(self, data: ThreadPayload) -> Thread: from .threads import Thread - return Thread(guild=self.guild, data=data) + return Thread(guild=self.guild, state=self.guild._state, data=data) diff --git a/discord/message.py b/discord/message.py index 825bda4db..b4604b383 100644 --- a/discord/message.py +++ b/discord/message.py @@ -1491,7 +1491,7 @@ class Message(Hashable): auto_archive_duration=auto_archive_duration, type=ChannelType.public_thread.value, ) - return Thread(guild=self.guild, data=data) # type: ignore + return Thread(guild=self.guild, state=self._state, data=data) # type: ignore async def reply(self, content: Optional[str] = None, **kwargs) -> Message: """|coro| diff --git a/discord/state.py b/discord/state.py index 5daa583e3..f4a6a664c 100644 --- a/discord/state.py +++ b/discord/state.py @@ -715,7 +715,7 @@ class ConnectionState: log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id) return - thread = Thread(guild=guild, data=data) + thread = Thread(guild=guild, state=guild._state, data=data) has_thread = guild.get_thread(thread.id) guild._add_thread(thread) if not has_thread: @@ -735,7 +735,7 @@ class ConnectionState: thread._update(data) self.dispatch('thread_update', old, thread) else: - thread = Thread(guild=guild, data=data) + thread = Thread(guild=guild, state=guild._state, data=data) guild._add_thread(thread) self.dispatch('thread_join', thread) diff --git a/discord/threads.py b/discord/threads.py index 85a370185..24eda6512 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -139,8 +139,8 @@ class Thread(Messageable, Hashable): 'archive_timestamp', ) - def __init__(self, *, guild: Guild, data: ThreadPayload): - self._state: ConnectionState = guild._state + def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): + self._state: ConnectionState = state self.guild = guild self._members: Dict[int, ThreadMember] = {} self._from_data(data)