From d20b444bfb54f01ce001cbdd02f7f30a60d9bc78 Mon Sep 17 00:00:00 2001 From: dolfies Date: Mon, 15 Nov 2021 20:55:29 -0500 Subject: [PATCH] Preliminary thread support --- discord/channel.py | 13 +-- discord/guild.py | 10 +-- discord/message.py | 2 +- discord/state.py | 83 ++++++++++++----- discord/threads.py | 217 +++++++++++++++++++++++---------------------- 5 files changed, 174 insertions(+), 151 deletions(-) diff --git a/discord/channel.py b/discord/channel.py index d21ae97ab..1063fd2e5 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -89,15 +89,6 @@ if TYPE_CHECKING: from .types.snowflake import SnowflakeList -async def _delete_messages(state, channel_id, messages): - delete_message = state.http.delete_message - for msg in messages: - try: - await delete_message(channel_id, msg.id) - except NotFound: - pass - - class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """Represents a Discord guild text channel. @@ -394,9 +385,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): messages = list(messages) if len(messages) == 0: - return # do nothing + return # Do nothing - await _delete_messages(self._state, self.id, messages) + await self._state._delete_messages(self.id, messages) async def purge( self, diff --git a/discord/guild.py b/discord/guild.py index 49e8c0089..5b06f0ecf 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -346,19 +346,13 @@ class Guild(Hashable): def _remove_thread(self, thread: Snowflake, /) -> None: self._threads.pop(thread.id, None) - def _clear_threads(self) -> None: - self._threads.clear() - def _remove_threads_by_channel(self, channel_id: int) -> None: to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id] for k in to_remove: del self._threads[k] def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]: - to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids} - for k in to_remove: - del self._threads[k] - return to_remove + return {k: t for k, t in self._threads.items() if t.parent_id in channel_ids} def __str__(self) -> str: return self.name or '' @@ -2856,7 +2850,7 @@ class Guild(Hashable): asyncio.TimeoutError The query timed out waiting for the members. ValueError - Invalid parameters were passed to the function + Invalid parameters were passed to the function. Returns -------- diff --git a/discord/message.py b/discord/message.py index 615d37522..a7ceceb78 100644 --- a/discord/message.py +++ b/discord/message.py @@ -898,7 +898,7 @@ class Message(Hashable): self.call = CallMessage(message=self, **call) def _handle_components(self, components: List[ComponentPayload]): - self.components = [_component_factory(d) for d in components] + self.components = [_component_factory(d, self) for d in components] def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None: self.guild = new_guild diff --git a/discord/state.py b/discord/state.py index a3093ae17..274d9e724 100644 --- a/discord/state.py +++ b/discord/state.py @@ -35,6 +35,7 @@ import time import os import random +from .errors import NotFound from .guild import Guild from .activity import BaseActivity from .user import User, ClientUser @@ -531,6 +532,14 @@ class ConnectionState: return channel or PartialMessageable(state=self, id=channel_id), guild + async def _delete_messages(self, channel_id, messages): + delete_message = self.http.delete_message + for msg in messages: + try: + await delete_message(channel_id, msg.id) + except NotFound: + pass + def request_guild(self, guild_id: int) -> None: return self.ws.request_lazy_guild(guild_id, typing=True, activities=True, threads=True) @@ -679,7 +688,7 @@ class ConnectionState: if guild_id in self._unavailable_guilds: # I don't know how I feel about this :( return - # Channel will be the correct type here + # channel will be the correct type here message = Message(channel=channel, data=data, state=self) # type: ignore self.dispatch('message', message) if self._messages is not None: @@ -944,10 +953,10 @@ class ConnectionState: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) return - thread = Thread(guild=guild, state=guild._state, data=data) + thread = Thread(guild=guild, state=self, data=data) has_thread = guild.get_thread(thread.id) guild._add_thread(thread) if not has_thread: @@ -988,49 +997,67 @@ class ConnectionState: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding.', guild_id) return try: channel_ids = set(data['channel_ids']) except KeyError: - # If not provided, then the entire guild is being synced - # So all previous thread data should be overwritten - previous_threads = guild._threads.copy() - guild._clear_threads() + channel_ids = None + threads = guild._threads.copy() else: - previous_threads = guild._filter_threads(channel_ids) - - threads = {d['id']: guild._store_thread(d) for d in data.get('threads', [])} + threads = guild._filter_threads(channel_ids) + + new_threads = {} + for d in data.get('threads', []): + if (thread := threads.pop(int(d['id']), None)) is not None: + old = thread._update(d) + if old is not None: # None = wasn't updated + self.dispatch('thread_update', old, thread) + else: + thread = Thread(guild=guild, state=self, data=d) + new_threads[thread.id] = thread + old_threads = [t for t in threads.values() if t not in new_threads] for member in data.get('members', []): try: - # note: member['id'] is the thread_id + # Note: member['id'] is the thread_id thread = threads[member['id']] except KeyError: continue else: thread._add_member(ThreadMember(thread, member)) - for thread in threads.values(): - old = previous_threads.pop(thread.id, None) - if old is None: - self.dispatch('thread_join', thread) + for k in new_threads.values(): + guild._add_thread(k) + self.dispatch('thread_join', k) - for thread in previous_threads.values(): - self.dispatch('thread_remove', thread) + for k in old_threads: + del guild._threads[k.id] + self.dispatch('thread_remove', k) + + for message in data.get('most_recent_messages', []): + guild_id = utils._get_as_snowflake(message, 'guild_id') + channel, _ = self._get_guild_channel(message) + if guild_id in self._unavailable_guilds: # I don't know how I feel about this :( + continue + + # channel will be the correct type here + message = Message(channel=channel, data=message, state=self) # type: ignore + if self._messages is not None: + self._messages.append(message) def parse_thread_member_update(self, data) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug('THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) return thread_id = int(data['id']) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding.', thread_id) return member = ThreadMember(thread, data) @@ -1040,13 +1067,13 @@ class ConnectionState: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) return thread_id = int(data['id']) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding.', thread_id) return added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])] @@ -1061,8 +1088,8 @@ class ConnectionState: self.dispatch('thread_join', thread) for member_id in removed_member_ids: + member = thread._pop_member(member_id) if member_id != self_id: - member = thread._pop_member(member_id) if member is not None: self.dispatch('thread_member_remove', member) else: @@ -1085,6 +1112,11 @@ class ConnectionState: # self.dispatch('member_join', member) + if (presence := data.get('presence')) is not None: + old_member = copy.copy(member) + member._presence_update(presence, tuple()) + self.dispatch('presence_update', old_member, member) + def parse_guild_member_remove(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: @@ -1130,6 +1162,11 @@ class ConnectionState: guild._add_member(member) _log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) + if (presence := data.get('presence')) is not None: + member._presence_update(presence, tuple()) + if old_member is not None: + self.dispatch('presence_update', old_member, member) + def parse_guild_sync(self, data) -> None: print('I noticed you triggered a `GUILD_SYNC`.\nIf you want to share your secrets, please feel free to email me.') diff --git a/discord/threads.py b/discord/threads.py index dfc89a1cb..39d6340ca 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -25,13 +25,13 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING -import time import asyncio +import copy from .mixins import Hashable from .abc import Messageable from .enums import ChannelType, try_enum -from .errors import ClientException +from .errors import ClientException, InvalidData from .utils import MISSING, parse_time, _get_as_snowflake __all__ = ( @@ -89,9 +89,9 @@ class Thread(Messageable, Hashable): id: :class:`int` The thread ID. parent_id: :class:`int` - The parent :class:`TextChannel` ID this thread belongs to. + The ID of the parent :class:`TextChannel` this thread belongs to. owner_id: :class:`int` - The user's ID that created this thread. + The ID of the user that created this thread. last_message_id: Optional[:class:`int`] The last message ID of the message sent to this thread. It may *not* point to an existing or valid message. @@ -104,9 +104,6 @@ class Thread(Messageable, Hashable): An approximate number of messages in this thread. This caps at 50. member_count: :class:`int` An approximate number of members in this thread. This caps at 50. - me: Optional[:class:`ThreadMember`] - A thread member representing yourself, if you've joined the thread. - This could not be available. archived: :class:`bool` Whether the thread is archived. locked: :class:`bool` @@ -115,7 +112,7 @@ class Thread(Messageable, Hashable): Whether non-moderators can add other non-moderators to this thread. This is always ``True`` for public threads. archiver_id: Optional[:class:`int`] - The user's ID that archived this thread. + The ID of the user that archived this thread. auto_archive_duration: :class:`int` The duration in minutes until the thread is automatically archived due to inactivity. Usually a value of 60, 1440, 4320 and 10080. @@ -136,13 +133,13 @@ class Thread(Messageable, Hashable): 'message_count', 'member_count', 'slowmode_delay', - 'me', 'locked', 'archived', 'invitable', 'archiver_id', 'auto_archive_duration', 'archive_timestamp', + 'member_ids', ) def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): @@ -173,12 +170,13 @@ class Thread(Messageable, Hashable): self.slowmode_delay = data.get('rate_limit_per_user', 0) self.message_count = data['message_count'] self.member_count = data['member_count'] + self.member_ids = data['member_ids_preview'] self._unroll_metadata(data['thread_metadata']) try: member = data['member'] except KeyError: - self.me = None + pass else: self.me = ThreadMember(self, member) @@ -191,17 +189,26 @@ class Thread(Messageable, Hashable): self.invitable = data.get('invitable', True) def _update(self, data): - try: - self.name = data['name'] - except KeyError: - pass - + old = copy.copy(self) self.slowmode_delay = data.get('rate_limit_per_user', 0) - try: - self._unroll_metadata(data['thread_metadata']) - except KeyError: - pass + if (meta := data.get('thread_metadata')) is not None: + self._unroll_metadata(meta) + if (name := data.get('name')) is not None: + self.name = name + if (last_message_id := _get_as_snowflake(data, 'last_message_id')) is not None: + self.last_message_id = last_message_id + if (message_count := data.get('message_count')) is not None: + self.message_count = message_count + if (member_count := data.get('member_count')) is not None: + self.member_count = member_count + if (member_ids := data.get('member_ids_preview')) is not None: + self.member_ids = member_ids + + attrs = [x for x in self.__slots__ if not any(y in x for y in ('member', 'guild', 'state', 'count'))] + + if any(getattr(self, attr) != getattr(old, attr) for attr in attrs): + return old @property def type(self) -> ChannelType: @@ -218,6 +225,11 @@ class Thread(Messageable, Hashable): """Optional[:class:`Member`]: The member this thread belongs to.""" return self.guild.get_member(self.owner_id) + @property + def archiver(self) -> Optional[Member]: + """Optional[:class:`Member`]: The member that archived this thread.""" + return self.guild.get_member(self.archiver_id) + @property def mention(self) -> str: """:class:`str`: The string that allows you to mention the thread.""" @@ -227,9 +239,8 @@ class Thread(Messageable, Hashable): def members(self) -> List[ThreadMember]: """List[:class:`ThreadMember`]: A list of thread members in this thread. - This requires :attr:`Intents.members` to be properly filled. Most of the time however, - this data is not provided by the gateway and a call to :meth:`fetch_members` is - needed. + Initial members are not provided by Discord. You must call :func:`fetch_members` + or have thread subscribing enabled. """ return list(self._members.values()) @@ -294,6 +305,19 @@ class Thread(Messageable, Hashable): raise ClientException('Parent channel not found') return parent.category_id + @property + def me(self) -> Optional[ThreadMember]: + """Optional[:class:`ThreadMember`]: A thread member representing yourself, if you've joined the thread. + + This might not be available. + """ + self_id = self._state.user.id + return self._members.get(self_id) + + @me.setter + def me(self, member) -> None: + self._members[member.id] = member + def is_private(self) -> bool: """:class:`bool`: Whether the thread is a private thread. @@ -357,17 +381,13 @@ class Thread(Messageable, Hashable): Deletes a list of messages. This is similar to :meth:`Message.delete` except it bulk deletes multiple messages. - As a special case, if the number of messages is 0, then nothing - is done. If the number of messages is 1 then single message - delete is done. If it's more than two, then bulk delete is used. - - You cannot bulk delete more than 100 messages or messages that - are older than 14 days old. - You must have the :attr:`~Permissions.manage_messages` permission to - use this. + use this (unless they're your own). - Usable only by bot accounts. + .. note:: + Users do not have access to the message bulk-delete endpoint. + Since messages are just iterated over and deleted one-by-one, + it's easy to get ratelimited using this method. Parameters ----------- @@ -376,13 +396,8 @@ class Thread(Messageable, Hashable): Raises ------ - ClientException - The number of messages to delete was more than 100. Forbidden - You do not have proper permissions to delete the messages or - you're not using a bot account. - NotFound - If single delete, then the message was already deleted. + You do not have proper permissions to delete the messages. HTTPException Deleting the messages failed. """ @@ -390,18 +405,9 @@ class Thread(Messageable, Hashable): messages = list(messages) if len(messages) == 0: - return # do nothing - - if len(messages) == 1: - message_id = messages[0].id - await self._state.http.delete_message(self.id, message_id) - return - - if len(messages) > 100: - raise ClientException('Can only bulk delete messages up to 100 messages') + return # Do nothing - message_ids: SnowflakeList = [m.id for m in messages] - await self._state.http.delete_messages(self.id, message_ids) + await self._state._delete_messages(self.id, messages) async def purge( self, @@ -412,7 +418,6 @@ class Thread(Messageable, Hashable): after: Optional[SnowflakeTime] = None, around: Optional[SnowflakeTime] = None, oldest_first: Optional[bool] = False, - bulk: bool = True, ) -> List[Message]: """|coro| @@ -420,10 +425,8 @@ class Thread(Messageable, Hashable): ``check``. If a ``check`` is not provided then all messages are deleted without discrimination. - You must have the :attr:`~Permissions.manage_messages` permission to - delete messages even if they are your own (unless you are a user - account). The :attr:`~Permissions.read_message_history` permission is - also needed to retrieve message history. + The :attr:`~Permissions.read_message_history` permission is needed to + retrieve message history. Examples --------- @@ -433,8 +436,8 @@ class Thread(Messageable, Hashable): def is_me(m): return m.author == client.user - deleted = await thread.purge(limit=100, check=is_me) - await thread.send(f'Deleted {len(deleted)} message(s)') + deleted = await channel.purge(limit=100, check=is_me) + await channel.send(f'Deleted {len(deleted)} message(s)') Parameters ----------- @@ -452,10 +455,6 @@ class Thread(Messageable, Hashable): Same as ``around`` in :meth:`history`. oldest_first: Optional[:class:`bool`] Same as ``oldest_first`` in :meth:`history`. - bulk: :class:`bool` - If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting - a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will - fall back to single delete if messages are older than two weeks. Raises ------- @@ -469,54 +468,30 @@ class Thread(Messageable, Hashable): List[:class:`.Message`] The list of messages that were deleted. """ - if check is MISSING: check = lambda m: True + state = self._state + channel_id = self.id iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) ret: List[Message] = [] count = 0 - minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 - - async def _single_delete_strategy(messages: Iterable[Message]): - for m in messages: - await m.delete() - - strategy = self.delete_messages if bulk else _single_delete_strategy - async for message in iterator: - if count == 100: - to_delete = ret[-100:] - await strategy(to_delete) + if count == 50: + to_delete = ret[-50:] + await state._delete_messages(channel_id, to_delete) count = 0 - await asyncio.sleep(1) if not check(message): continue - if message.id < minimum_time: - # older than 14 days old - if count == 1: - await ret[-1].delete() - elif count >= 2: - to_delete = ret[-count:] - await strategy(to_delete) - - count = 0 - strategy = _single_delete_strategy - count += 1 ret.append(message) - # SOme messages remaining to poll - if count >= 2: - # more than 2 messages -> bulk delete - to_delete = ret[-count:] - await strategy(to_delete) - elif count == 1: - # delete a single message - await ret[-1].delete() + # Some messages remaining to poll + to_delete = ret[-count:] + await state._delete_messages(channel_id, to_delete) return ret @@ -666,23 +641,32 @@ class Thread(Messageable, Hashable): async def fetch_members(self) -> List[ThreadMember]: """|coro| - Retrieves all :class:`ThreadMember` that are in this thread. - - This requires :attr:`Intents.members` to get information about members - other than yourself. + Retrieves all :class:`ThreadMember` that are in this thread, + along with their respective :class:`Member`. Raises ------- - HTTPException - Retrieving the members failed. + InvalidData + Discord didn't respond with the members. Returns -------- List[:class:`ThreadMember`] All thread members in the thread. """ - members = await self._state.http.get_thread_members(self.id) - return [ThreadMember(parent=self, data=data) for data in members] + state = self._state + await state.ws.request_lazy_guild(self.parent.guild.id, thread_member_lists=[self.id]) # type: ignore + future = state.ws.wait_for('THREAD_MEMBER_LIST_UPDATE', lambda d: int(d['thread_id']) == self.id) + try: + data = await asyncio.wait_for(future, timeout=30) + except asyncio.TimeoutError as exc: + raise InvalidData('Failed to retrieve members') from exc + + members = [ThreadMember(self, {'member': member}) for member in data['members']] + for m in members: + self._add_member(m) + + return self.members # Includes correct self.me async def delete(self): """|coro| @@ -718,13 +702,13 @@ class Thread(Messageable, Hashable): :class:`PartialMessage` The partial message. """ - from .message import PartialMessage return PartialMessage(channel=self, id=message_id) def _add_member(self, member: ThreadMember) -> None: - self._members[member.id] = member + if member.id != self._state.self_id: + self._members[member.id] = member def _pop_member(self, member_id: int) -> Optional[ThreadMember]: return self._members.pop(member_id, None) @@ -759,8 +743,12 @@ class ThreadMember(Hashable): The thread member's ID. thread_id: :class:`int` The thread's ID. - joined_at: :class:`datetime.datetime` + joined_at: Optional[:class:`datetime.datetime`] The time the member joined the thread in UTC. + Only reliably available for yourself or members joined while the user is connected to the gateway. + flags: :class:`int` + The thread member's flags. Will be its own class in the future. + Only reliably available for yourself or members joined while the user is connected to the gateway. """ __slots__ = ( @@ -781,19 +769,32 @@ class ThreadMember(Hashable): return f'' def _from_data(self, data: ThreadMemberPayload): + state = self._state + try: self.id = int(data['user_id']) except KeyError: - assert self._state.self_id is not None - self.id = self._state.self_id + assert state.self_id is not None + self.id = state.self_id try: self.thread_id = int(data['id']) except KeyError: self.thread_id = self.parent.id - self.joined_at = parse_time(data['join_timestamp']) - self.flags = data['flags'] + self.joined_at = parse_time(data.get('join_timestamp')) + self.flags = data.get('flags') + + if (data := data.get('member')) is not None: + guild = self.parent.parent.guild # type: ignore + mdata = data['member'] + mdata['guild_id'] = guild.id + self.id = user_id = int(data['user_id']) + mdata['presence'] = data.get('presence') + if guild.get_member(user_id) is not None: + state.parse_guild_member_update(mdata) + else: + state.parse_guild_member_add(mdata) @property def thread(self) -> Thread: @@ -802,7 +803,7 @@ class ThreadMember(Hashable): @property def member(self) -> Optional[Member]: - """Optional[:class:`Member`]: The member this member represents. If the member + """Optional[:class:`Member`]: The member this :class:`ThreadMember` represents. If the member is not cached then this will be ``None``. """ - return self.parent.guild.get_member(self.id) + return self.parent.parent.guild.get_member(self.id) # type: ignore