From c1ce3b949fe2dbc3d34a86d3a7973a7ed60566df Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 14 Apr 2021 09:42:38 -0400 Subject: [PATCH] Implement remaining HTTP endpoints on threads I'm not sure if I missed any -- but this is the entire documented set so far. --- discord/channel.py | 81 ++++++++++++++++++++++++- discord/http.py | 16 +++-- discord/iterators.py | 97 +++++++++++++++++++++++++++++- discord/message.py | 43 ++++++++++++- discord/threads.py | 139 ++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 365 insertions(+), 11 deletions(-) diff --git a/discord/channel.py b/discord/channel.py index 5179c8c4b..0e28f52b0 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -27,6 +27,7 @@ from __future__ import annotations import time import asyncio from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union, overload +import datetime import discord.abc from .permissions import PermissionOverwrite, Permissions @@ -36,6 +37,8 @@ from . import utils from .asset import Asset from .errors import ClientException, NoMoreItems, InvalidArgument from .stage_instance import StageInstance +from .threads import Thread +from .iterators import ArchivedThreadIterator __all__ = ( 'TextChannel', @@ -49,12 +52,12 @@ __all__ = ( ) if TYPE_CHECKING: + from .types.threads import ThreadArchiveDuration from .role import Role from .member import Member, VoiceState - from .abc import Snowflake + from .abc import Snowflake, SnowflakeTime from .message import Message from .webhook import Webhook - from .abc import SnowflakeTime async def _single_delete_strategy(messages): for m in messages: @@ -586,6 +589,80 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): from .message import PartialMessage return PartialMessage(channel=self, id=message_id) + async def start_private_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = 1440) -> Thread: + """|coro| + + Starts a private thread in this text channel. + + You must have :attr:`~discord.Permissions.send_messages` and + :attr:`~discord.Permissions.use_private_threads` in order to start a thread. + + Parameters + ----------- + name: :class:`str` + The name of the thread. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + Defaults to ``1440`` or 24 hours. + + Raises + ------- + Forbidden + You do not have permissions to start a thread. + HTTPException + Starting the thread failed. + """ + + data = await self._state.http.start_public_thread( + self.id, + name=name, + auto_archive_duration=auto_archive_duration, + type=ChannelType.private_thread.value, + ) + return Thread(guild=self.guild, data=data) + + async def archive_threads( + self, + *, + private: bool = True, + joined: bool = False, + limit: Optional[int] = 50, + before: Optional[Union[Snowflake, datetime.datetime]] = None, + ) -> ArchivedThreadIterator: + """Returns an :class:`~discord.AsyncIterator` that iterates over all archived threads in the guild. + + You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads + then :attr:`~Permissions.manage_messages` is also required. + + Parameters + ----------- + limit: Optional[:class:`bool`] + The number of threads to retrieve. + If ``None``, retrieves every archived thread in the channel. Note, however, + that this would make it a slow operation. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve archived channels before the given date or ID. + private: :class:`bool` + Whether to retrieve private archived threads. + joined: :class:`bool` + Whether to retrieve private archived threads that you've joined. + You cannot set ``joined`` to ``True`` and ``private`` to ``False``. + + Raises + ------ + Forbidden + You do not have permissions to get archived threads. + HTTPException + The request to get the archived threads failed. + + Yields + ------- + :class:`Thread` + The archived threads. + """ + return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before) + + class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): __slots__ = ('name', 'id', 'guild', 'bitrate', 'user_limit', '_state', 'position', '_overwrites', 'category_id', diff --git a/discord/http.py b/discord/http.py index ee93d1a92..1d27996c4 100644 --- a/discord/http.py +++ b/discord/http.py @@ -785,11 +785,17 @@ class HTTPClient: route = Route('DELETE', '/channels/{channel_id}/thread-members/{user_id}', channel_id=channel_id, user_id=user_id) return self.request(route) - def get_archived_threads(self, channel_id: int, before=None, limit: int = 50, public: bool = True): - if public: - route = Route('GET', '/channels/{channel_id}/threads/archived/public', channel_id=channel_id) - else: - route = Route('GET', '/channels/{channel_id}/threads/archived/private', channel_id=channel_id) + def get_public_archived_threads(self, channel_id: int, before=None, limit: int = 50): + route = Route('GET', '/channels/{channel_id}/threads/archived/public', channel_id=channel_id) + + params = {} + if before: + params['before'] = before + params['limit'] = limit + return self.request(route, params=params) + + def get_private_archived_threads(self, channel_id: int, before=None, limit: int = 50): + route = Route('GET', '/channels/{channel_id}/threads/archived/private', channel_id=channel_id) params = {} if before: diff --git a/discord/iterators.py b/discord/iterators.py index cc082f55c..28e98ec46 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -29,7 +29,7 @@ import datetime from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator from .errors import NoMoreItems -from .utils import time_snowflake, maybe_coroutine +from .utils import snowflake_time, time_snowflake, maybe_coroutine from .object import Object from .audit_logs import AuditLogEntry @@ -55,11 +55,17 @@ if TYPE_CHECKING: PartialUser as PartialUserPayload, ) + from .types.threads import ( + Thread as ThreadPayload, + ) + from .member import Member from .user import User from .message import Message from .audit_logs import AuditLogEntry from .guild import Guild + from .threads import Thread + from .abc import Snowflake T = TypeVar('T') OT = TypeVar('OT') @@ -655,3 +661,92 @@ class MemberIterator(_AsyncIterator['Member']): from .member import Member return Member(data=data, guild=self.guild, state=self.state) + + +class ArchivedThreadIterator(_AsyncIterator['Thread']): + def __init__( + self, + channel_id: int, + guild: Guild, + limit: Optional[int], + joined: bool, + private: bool, + before: Optional[Union[Snowflake, datetime.datetime]] = None, + ): + self.channel_id = channel_id + self.guild = guild + self.limit = limit + self.joined = joined + self.private = private + self.http = guild._state.http + + if joined and not private: + raise ValueError('Cannot iterate over joined public archived threads') + + self.before: Optional[str] + if before is None: + self.before = None + elif isinstance(before, datetime.datetime): + if joined: + self.before = str(time_snowflake(before, high=False)) + else: + self.before = before.isoformat() + else: + if joined: + self.before = str(before.id) + else: + self.before = snowflake_time(before.id).isoformat() + + self.update_before: Callable[[ThreadPayload], str] = self.get_archive_timestamp + + if joined: + self.endpoint = self.http.get_joined_private_archived_threads + self.update_before = self.get_thread_id + elif private: + self.endpoint = self.http.get_private_archived_threads + else: + self.endpoint = self.http.get_archived_threads + + self.queue: asyncio.Queue[Thread] = asyncio.Queue() + self.has_more: bool = True + + async def next(self) -> Thread: + if self.queue.empty(): + await self.fill_queue() + + try: + return self.queue.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems() + + @staticmethod + def get_archive_timestamp(data: ThreadPayload) -> str: + return data['thread_metadata']['archive_timestamp'] + + @staticmethod + def get_thread_id(data: ThreadPayload) -> str: + return data['id'] # type: ignore + + async def fill_queue(self) -> None: + if not self.has_more: + raise NoMoreItems() + + limit = 50 if self.limit is None else max(self.limit, 50) + data = await self.endpoint(self.channel_id, before=self.before, limit=limit) + + # This stuff is obviously WIP because 'members' is always empty + threads: List[ThreadPayload] = data.get('threads', []) + for d in reversed(threads): + self.queue.put_nowait(self.create_thread(d)) + + self.has_more = data.get('has_more', False) + if self.limit is not None: + self.limit -= len(threads) + if self.limit <= 0: + self.has_more = False + + if self.has_more: + self.before = self.update_before(threads[-1]) + + def create_thread(self, data: ThreadPayload) -> Thread: + return Thread(guild=self.guild, data=data) diff --git a/discord/message.py b/discord/message.py index 42ef6bf7d..dc240d296 100644 --- a/discord/message.py +++ b/discord/message.py @@ -46,6 +46,7 @@ from .utils import escape_mentions from .guild import Guild from .mixins import Hashable from .sticker import Sticker +from .threads import Thread if TYPE_CHECKING: from .types.message import ( @@ -58,7 +59,7 @@ if TYPE_CHECKING: ) from .types.components import Component as ComponentPayload - + from .types.threads import ThreadArchiveDuration from .types.member import Member as MemberPayload from .types.user import User as UserPayload from .types.embed import Embed as EmbedPayload @@ -79,7 +80,6 @@ __all__ = ( 'DeletedReferencedMessage', ) - def convert_emoji_reaction(emoji): if isinstance(emoji, Reaction): emoji = emoji.emoji @@ -1429,6 +1429,45 @@ class Message(Hashable): """ await self._state.http.clear_reactions(self.channel.id, self.id) + async def start_public_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = 1440) -> Thread: + """|coro| + + Starts a public thread from this message. + + You must have :attr:`~discord.Permissions.send_messages` and + :attr:`~discord.Permissions.use_threads` in order to start a thread. + + The channel this message belongs in must be a :class:`TextChannel`. + + Parameters + ----------- + name: :class:`str` + The name of the thread. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + Defaults to ``1440`` or 24 hours. + + Raises + ------- + Forbidden + You do not have permissions to start a thread. + HTTPException + Starting the thread failed. + InvalidArgument + This message does not have guild info attached. + """ + if self.guild is None: + raise InvalidArgument('This message does not have guild info attached.') + + data = await self._state.http.start_public_thread( + self.channel.id, + self.id, + name=name, + auto_archive_duration=auto_archive_duration, + type=ChannelType.public_thread.value, + ) + return Thread(guild=self.guild, data=data) # type: ignore + async def reply(self, content: Optional[str] = None, **kwargs) -> Message: """|coro| diff --git a/discord/threads.py b/discord/threads.py index cde4a3299..0d2b6f960 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -40,10 +40,13 @@ if TYPE_CHECKING: Thread as ThreadPayload, ThreadMember as ThreadMemberPayload, ThreadMetadata, + ThreadArchiveDuration, ) from .guild import Guild from .channel import TextChannel from .member import Member + from .message import Message + from .abc import Snowflake class Thread(Messageable, Hashable): @@ -171,7 +174,7 @@ class Thread(Messageable, Hashable): return self.guild.get_member(self.owner_id) @property - def last_message(self): + def last_message(self) -> Optional[Message]: """Fetches the last message from this channel in cache. The message might not be valid or point to an existing message. @@ -191,6 +194,140 @@ class Thread(Messageable, Hashable): """ return self._state._get_message(self.last_message_id) if self.last_message_id else None + def is_private(self) -> bool: + """:class:`bool`: Whether the thread is a private thread.""" + return self.type is ChannelType.private_thread + + async def edit( + self, + *, + name: str = ..., + archived: bool = ..., + auto_archive_duration: ThreadArchiveDuration = ..., + ): + """|coro| + + Edits the thread. + + To unarchive a thread :attr:`~.Permissions.send_messages` is required. Otherwise, + :attr:`~.Permissions.manage_messages` is required to edit the thread. + + Parameters + ------------ + name: :class:`str` + The new name of the thread. + archived: :class:`bool` + Whether to archive the thread or not. + auto_archive_duration: :class:`int` + The new duration to auto archive threads for inactivity. + + Raises + ------- + Forbidden + You do not have permissions to edit the thread. + HTTPException + Editing the thread failed. + """ + payload = {} + if name is not ...: + payload['name'] = str(name) + if archived is not ...: + payload['archived'] = archived + if auto_archive_duration is not ...: + payload['auto_archive_duration'] = auto_archive_duration + await self._state.http.edit_channel(self.id, **payload) + + async def join(self): + """|coro| + + Joins this thread. + + You must have :attr:`~Permissions.send_messages` and :attr:`~Permissions.use_threads` + to join a public thread. If the thread is private then :attr:`~Permissions.send_messages` + and either :attr:`~Permissions.use_private_threads` or :attr:`~Permissions.manage_messages` + is required to join the thread. + + Raises + ------- + Forbidden + You do not have permissions to join the thread. + HTTPException + Joining the thread failed. + """ + await self._state.http.join_thread(self.id) + + async def leave(self): + """|coro| + + Leaves this thread. + + Raises + ------- + HTTPException + Leaving the thread failed. + """ + await self._state.http.leave_thread(self.id) + + async def add_user(self, user: Snowflake): + """|coro| + + Adds a user to this thread. + + You must have :attr:`~Permissions.send_messages` and :attr:`~Permissions.use_threads` + to add a user to a public thread. If the thread is private then :attr:`~Permissions.send_messages` + and either :attr:`~Permissions.use_private_threads` or :attr:`~Permissions.manage_messages` + is required to add a user to the thread. + + Parameters + ----------- + user: :class:`abc.Snowflake` + The user to add to the thread. + + Raises + ------- + Forbidden + You do not have permissions to add the user to the thread. + HTTPException + Adding the user to the thread failed. + """ + await self._state.http.add_user_to_thread(self.id, user.id) + + async def remove_user(self, user: Snowflake): + """|coro| + + Removes a user from this thread. + + You must have :attr:`~Permissions.manage_messages` or be the creator of the thread to remove a user. + + Parameters + ----------- + user: :class:`abc.Snowflake` + The user to add to the thread. + + Raises + ------- + Forbidden + You do not have permissions to remove the user from the thread. + HTTPException + Removing the user from the thread failed. + """ + await self._state.http.remove_user_from_thread(self.id, user.id) + + async def delete(self): + """|coro| + + Deletes this thread. + + You must have :attr:`~Permissions.manage_channels` to delete threads. + + Raises + ------- + Forbidden + You do not have permissions to delete this thread. + HTTPException + Deleting the thread failed. + """ + await self._state.http.delete_channel(self.id) class ThreadMember(Hashable): """Represents a Discord thread member.