From bd369c76ea5424f65e37d84bfb45df1c76a4e739 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sun, 9 May 2021 22:23:21 -0400 Subject: [PATCH] Parse remaining thread events. --- discord/guild.py | 16 +++++++++++++- discord/state.py | 55 +++++++++++++++++++++++++++++++++++++++++----- discord/threads.py | 2 ++ docs/api.rst | 50 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 114 insertions(+), 9 deletions(-) diff --git a/discord/guild.py b/discord/guild.py index 4e7e38f67..e32c2fcf7 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -26,7 +26,7 @@ from __future__ import annotations import copy from collections import namedtuple -from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union, overload +from typing import Dict, List, Set, Literal, Optional, TYPE_CHECKING, Union, overload from . import utils, abc from .role import Role @@ -227,6 +227,20 @@ class Guild(Hashable): def _remove_thread(self, thread): self._threads.pop(thread.id, None) + def _clear_threads(self): + self._threads.clear() + + def _remove_threads_by_channel(self, channel_id: int): + to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id] + for k in to_remove: + del self._threads[k] + + def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]: + to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids} + for k in to_remove: + del self._threads[k] + return to_remove + def __str__(self): return self.name or '' diff --git a/discord/state.py b/discord/state.py index 85a94cee9..d8c322ace 100644 --- a/discord/state.py +++ b/discord/state.py @@ -716,7 +716,7 @@ class ConnectionState: thread = Thread(guild=guild, data=data) guild._add_thread(thread) - self.dispatch('thread_create', thread) + self.dispatch('thread_join', thread) def parse_thread_update(self, data): guild_id = int(data['guild_id']) @@ -752,6 +752,16 @@ class ConnectionState: log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id) return + try: + channel_ids = set(data['channel_ids']) + except KeyError: + # If not provided, then the entire guild is being synced + # So all previous thread data should be overwritten + previous_threads = guild._threads.copy() + guild._clear_threads() + else: + previous_threads = guild._filter_threads(channel_ids) + threads = { d['id']: guild._store_thread(d) for d in data.get('threads', []) @@ -766,7 +776,13 @@ class ConnectionState: else: thread._add_member(ThreadMember(thread, member)) - # TODO: dispatch? + for thread in threads.values(): + old = previous_threads.pop(thread.id, None) + if old is None: + self.dispatch('thread_join', thread) + + for thread in previous_threads.values(): + self.dispatch('thread_remove', thread) def parse_thread_member_update(self, data): guild_id = int(data['guild_id']) @@ -776,15 +792,44 @@ class ConnectionState: return thread_id = int(data['id']) - thread = guild.get_thread(thread_id) + thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) return member = ThreadMember(thread, data) - thread._add_member(member) + thread.me = member + + def parse_thread_members_update(self, data): + guild_id = int(data['guild_id']) + guild: Optional[Guild] = self._get_guild(guild_id) + if guild is None: + log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + return - # TODO: dispatch + thread_id = int(data['id']) + thread: Optional[Thread] = guild.get_thread(thread_id) + if thread is None: + log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + return + + added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])] + removed_member_ids = data.get('removed_member_ids', []) + self_id = self.self_id + for member in added_members: + if member.id != self_id: + thread._add_member(member) + self.dispatch('thread_member_join', member) + else: + thread.me = member + self.dispatch('thread_join', thread) + + for member_id in removed_member_ids: + if member_id != self_id: + member = thread._pop_member(member_id) + self.dispatch('thread_member_leave', member) + else: + self.dispatch('thread_remove', thread) def parse_guild_member_add(self, data): guild = self._get_guild(int(data['guild_id'])) diff --git a/discord/threads.py b/discord/threads.py index 1a8a6af1e..cf6d92aa8 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -383,6 +383,8 @@ class Thread(Messageable, Hashable): def _add_member(self, member: ThreadMember) -> None: self._members[member.id] = member + def _pop_member(self, member_id: int) -> Optional[ThreadMember]: + return self._members.pop(member_id, None) class ThreadMember(Hashable): """Represents a Discord thread member. diff --git a/docs/api.rst b/docs/api.rst index 4282373ec..395de4771 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -658,10 +658,42 @@ to handle it, which defaults to print a traceback and ignoring the exception. :param last_pin: The latest message that was pinned as an aware datetime in UTC. Could be ``None``. :type last_pin: Optional[:class:`datetime.datetime`] +.. function:: on_thread_join(thread) + + Called whenever a thread is joined. + + Note that you can get the guild from :attr:`Thread.guild`. + + This requires :attr:`Intents.guilds` to be enabled. + + .. versionadded:: 2.0 + + :param thread: The thread that got joined. + :type thread: :class:`Thread` + +.. function:: on_thread_remove(thread) + + Called whenever a thread is removed. This is different from a thread being deleted. + + Note that you can get the guild from :attr:`Thread.guild`. + + This requires :attr:`Intents.guilds` to be enabled. + + .. warning:: + + Due to technical limitations, this event might not be called + as soon as one expects. Since the library tracks thread membership + locally, the API only sends updated thread membership status upon being + synced by joining a thread. + + .. versionadded:: 2.0 + + :param thread: The thread that got removed. + :type thread: :class:`Thread` + .. function:: on_thread_delete(thread) - on_thread_create(thread) - Called whenever a thread is deleted or created. + Called whenever a thread is deleted. Note that you can get the guild from :attr:`Thread.guild`. @@ -669,9 +701,21 @@ to handle it, which defaults to print a traceback and ignoring the exception. .. versionadded:: 2.0 - :param thread: The thread that got created or deleted. + :param thread: The thread that got deleted. :type thread: :class:`Thread` +.. function:: on_thread_member_join(member) + on_thread_member_remove(member) + + Called when a :class:`ThreadMember` leaves or joins a :class:`Thread`. + + You can get the thread a member belongs in by accessing :attr:`ThreadMember.thread`. + + This requires :attr:`Intents.members` to be enabled. + + :param member: The member who joined or left. + :type member: :class:`ThreadMember` + .. function:: on_thread_update(before, after) Called whenever a thread is updated.