diff --git a/discord/channel.py b/discord/channel.py index f6f7d6b2b..ea7179b90 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -778,29 +778,6 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before) - async def active_threads(self) -> List[Thread]: - """|coro| - - Returns a list of active :class:`Thread` that the client can access. - - This includes both private and public threads. - - .. versionadded:: 2.0 - - Raises - ------ - HTTPException - The request to get the active threads failed. - - Returns - -------- - List[:class:`Thread`] - The archived threads - """ - data = await self._state.http.get_active_threads(self.id) - # TODO: thread members? - return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])] - class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): __slots__ = ( diff --git a/discord/guild.py b/discord/guild.py index ebdc8764e..dabea5d05 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -73,7 +73,7 @@ from .asset import Asset from .flags import SystemChannelFlags from .integrations import Integration, _integration_factory from .stage_instance import StageInstance -from .threads import Thread +from .threads import Thread, ThreadMember from .sticker import GuildSticker from .file import File @@ -423,7 +423,9 @@ class Guild(Hashable): self.mfa_level: MFALevel = guild.get('mfa_level') self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', []))) - self.stickers: Tuple[GuildSticker, ...] = tuple(map(lambda d: state.store_sticker(self, d), guild.get('stickers', []))) + self.stickers: Tuple[GuildSticker, ...] = tuple( + map(lambda d: state.store_sticker(self, d), guild.get('stickers', [])) + ) self.features: List[GuildFeature] = guild.get('features', []) self._splash: Optional[str] = guild.get('splash') self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'system_channel_id') @@ -628,7 +630,6 @@ class Guild(Hashable): """ return self._channels.get(channel_id) or self._threads.get(channel_id) - def get_channel(self, channel_id: int, /) -> Optional[GuildChannel]: """Returns a channel with the given ID. @@ -1591,6 +1592,35 @@ class Guild(Hashable): return [convert(d) for d in data] + async def active_threads(self) -> List[Thread]: + """|coro| + + Returns a list of active :class:`Thread` that the client can access. + + This includes both private and public threads. + + .. versionadded:: 2.0 + + Raises + ------ + HTTPException + The request to get the active threads failed. + + Returns + -------- + List[:class:`Thread`] + The active threads + """ + data = await self._state.http.get_active_threads(self.id) + threads = [Thread(guild=self, state=self._state, data=d) for d in data.get('threads', [])] + thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads} + for member in data.get('members', []): + thread = thread_lookup.get(int(member['id'])) + if thread is not None: + thread._add_member(ThreadMember(parent=thread, data=member)) + + return threads + # TODO: Remove Optional typing here when async iterators are refactored def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this, diff --git a/discord/http.py b/discord/http.py index b186782ff..868910bae 100644 --- a/discord/http.py +++ b/discord/http.py @@ -958,8 +958,8 @@ class HTTPClient: params['limit'] = limit return self.request(route, params=params) - def get_active_threads(self, channel_id: Snowflake) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/channels/{channel_id}/threads/active', channel_id=channel_id) + def get_active_threads(self, guild_id: Snowflake) -> Response[threads.ThreadPaginationPayload]: + route = Route('GET', '/guilds/{guild_id}/threads/active', guild_id=guild_id) return self.request(route) def get_thread_members(self, channel_id: Snowflake) -> Response[List[threads.ThreadMember]]: