diff --git a/discord/channel.py b/discord/channel.py index ffb11ffcb..f6f7d6b2b 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -2038,3 +2038,10 @@ def _threaded_channel_factory(channel_type: int): if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread): return Thread, value return cls, value + + +def _threaded_guild_channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread): + return Thread, value + return cls, value diff --git a/discord/guild.py b/discord/guild.py index c61ed7f05..ebdc8764e 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -52,6 +52,7 @@ from .colour import Colour from .errors import InvalidArgument, ClientException from .channel import * from .channel import _guild_channel_factory +from .channel import _threaded_guild_channel_factory from .enums import ( AuditLogAction, VideoQualityMode, @@ -1703,14 +1704,14 @@ class Guild(Hashable): data: BanPayload = await self._state.http.get_ban(user.id, self.id) return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason']) - async def fetch_channel(self, channel_id: int, /) -> GuildChannel: + async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]: """|coro| - Retrieves a :class:`.abc.GuildChannel` with the specified ID. + Retrieves a :class:`.abc.GuildChannel` or :class:`.Thread` with the specified ID. .. note:: - This method is an API call. For general usage, consider :meth:`get_channel` instead. + This method is an API call. For general usage, consider :meth:`get_channel_or_thread` instead. .. versionadded:: 2.0 @@ -1729,12 +1730,12 @@ class Guild(Hashable): Returns -------- - :class:`.abc.GuildChannel` + Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel from the ID. """ data = await self._state.http.get_channel(channel_id) - factory, ch_type = _guild_channel_factory(data['type']) + factory, ch_type = _threaded_guild_channel_factory(data['type']) if factory is None: raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))