Browse Source

Preliminary thread support

pull/10109/head
dolfies 3 years ago
parent
commit
d20b444bfb
  1. 13
      discord/channel.py
  2. 10
      discord/guild.py
  3. 2
      discord/message.py
  4. 83
      discord/state.py
  5. 217
      discord/threads.py

13
discord/channel.py

@ -89,15 +89,6 @@ if TYPE_CHECKING:
from .types.snowflake import SnowflakeList
async def _delete_messages(state, channel_id, messages):
delete_message = state.http.delete_message
for msg in messages:
try:
await delete_message(channel_id, msg.id)
except NotFound:
pass
class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""Represents a Discord guild text channel.
@ -394,9 +385,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
messages = list(messages)
if len(messages) == 0:
return # do nothing
return # Do nothing
await _delete_messages(self._state, self.id, messages)
await self._state._delete_messages(self.id, messages)
async def purge(
self,

10
discord/guild.py

@ -346,19 +346,13 @@ class Guild(Hashable):
def _remove_thread(self, thread: Snowflake, /) -> None:
self._threads.pop(thread.id, None)
def _clear_threads(self) -> None:
self._threads.clear()
def _remove_threads_by_channel(self, channel_id: int) -> None:
to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id]
for k in to_remove:
del self._threads[k]
def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]:
to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids}
for k in to_remove:
del self._threads[k]
return to_remove
return {k: t for k, t in self._threads.items() if t.parent_id in channel_ids}
def __str__(self) -> str:
return self.name or ''
@ -2856,7 +2850,7 @@ class Guild(Hashable):
asyncio.TimeoutError
The query timed out waiting for the members.
ValueError
Invalid parameters were passed to the function
Invalid parameters were passed to the function.
Returns
--------

2
discord/message.py

@ -898,7 +898,7 @@ class Message(Hashable):
self.call = CallMessage(message=self, **call)
def _handle_components(self, components: List[ComponentPayload]):
self.components = [_component_factory(d) for d in components]
self.components = [_component_factory(d, self) for d in components]
def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None:
self.guild = new_guild

83
discord/state.py

@ -35,6 +35,7 @@ import time
import os
import random
from .errors import NotFound
from .guild import Guild
from .activity import BaseActivity
from .user import User, ClientUser
@ -531,6 +532,14 @@ class ConnectionState:
return channel or PartialMessageable(state=self, id=channel_id), guild
async def _delete_messages(self, channel_id, messages):
delete_message = self.http.delete_message
for msg in messages:
try:
await delete_message(channel_id, msg.id)
except NotFound:
pass
def request_guild(self, guild_id: int) -> None:
return self.ws.request_lazy_guild(guild_id, typing=True, activities=True, threads=True)
@ -679,7 +688,7 @@ class ConnectionState:
if guild_id in self._unavailable_guilds: # I don't know how I feel about this :(
return
# Channel will be the correct type here
# channel will be the correct type here
message = Message(channel=channel, data=data, state=self) # type: ignore
self.dispatch('message', message)
if self._messages is not None:
@ -944,10 +953,10 @@ class ConnectionState:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
_log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id)
_log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id)
return
thread = Thread(guild=guild, state=guild._state, data=data)
thread = Thread(guild=guild, state=self, data=data)
has_thread = guild.get_thread(thread.id)
guild._add_thread(thread)
if not has_thread:
@ -988,49 +997,67 @@ class ConnectionState:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
_log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id)
_log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding.', guild_id)
return
try:
channel_ids = set(data['channel_ids'])
except KeyError:
# If not provided, then the entire guild is being synced
# So all previous thread data should be overwritten
previous_threads = guild._threads.copy()
guild._clear_threads()
channel_ids = None
threads = guild._threads.copy()
else:
previous_threads = guild._filter_threads(channel_ids)
threads = {d['id']: guild._store_thread(d) for d in data.get('threads', [])}
threads = guild._filter_threads(channel_ids)
new_threads = {}
for d in data.get('threads', []):
if (thread := threads.pop(int(d['id']), None)) is not None:
old = thread._update(d)
if old is not None: # None = wasn't updated
self.dispatch('thread_update', old, thread)
else:
thread = Thread(guild=guild, state=self, data=d)
new_threads[thread.id] = thread
old_threads = [t for t in threads.values() if t not in new_threads]
for member in data.get('members', []):
try:
# note: member['id'] is the thread_id
# Note: member['id'] is the thread_id
thread = threads[member['id']]
except KeyError:
continue
else:
thread._add_member(ThreadMember(thread, member))
for thread in threads.values():
old = previous_threads.pop(thread.id, None)
if old is None:
self.dispatch('thread_join', thread)
for k in new_threads.values():
guild._add_thread(k)
self.dispatch('thread_join', k)
for thread in previous_threads.values():
self.dispatch('thread_remove', thread)
for k in old_threads:
del guild._threads[k.id]
self.dispatch('thread_remove', k)
for message in data.get('most_recent_messages', []):
guild_id = utils._get_as_snowflake(message, 'guild_id')
channel, _ = self._get_guild_channel(message)
if guild_id in self._unavailable_guilds: # I don't know how I feel about this :(
continue
# channel will be the correct type here
message = Message(channel=channel, data=message, state=self) # type: ignore
if self._messages is not None:
self._messages.append(message)
def parse_thread_member_update(self, data) -> None:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
_log.debug('THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id)
_log.debug('THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
return
thread_id = int(data['id'])
thread: Optional[Thread] = guild.get_thread(thread_id)
if thread is None:
_log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
_log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding.', thread_id)
return
member = ThreadMember(thread, data)
@ -1040,13 +1067,13 @@ class ConnectionState:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
if guild is None:
_log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id)
_log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
return
thread_id = int(data['id'])
thread: Optional[Thread] = guild.get_thread(thread_id)
if thread is None:
_log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
_log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding.', thread_id)
return
added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])]
@ -1061,8 +1088,8 @@ class ConnectionState:
self.dispatch('thread_join', thread)
for member_id in removed_member_ids:
member = thread._pop_member(member_id)
if member_id != self_id:
member = thread._pop_member(member_id)
if member is not None:
self.dispatch('thread_member_remove', member)
else:
@ -1085,6 +1112,11 @@ class ConnectionState:
# self.dispatch('member_join', member)
if (presence := data.get('presence')) is not None:
old_member = copy.copy(member)
member._presence_update(presence, tuple())
self.dispatch('presence_update', old_member, member)
def parse_guild_member_remove(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
@ -1130,6 +1162,11 @@ class ConnectionState:
guild._add_member(member)
_log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id)
if (presence := data.get('presence')) is not None:
member._presence_update(presence, tuple())
if old_member is not None:
self.dispatch('presence_update', old_member, member)
def parse_guild_sync(self, data) -> None:
print('I noticed you triggered a `GUILD_SYNC`.\nIf you want to share your secrets, please feel free to email me.')

217
discord/threads.py

@ -25,13 +25,13 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
import time
import asyncio
import copy
from .mixins import Hashable
from .abc import Messageable
from .enums import ChannelType, try_enum
from .errors import ClientException
from .errors import ClientException, InvalidData
from .utils import MISSING, parse_time, _get_as_snowflake
__all__ = (
@ -89,9 +89,9 @@ class Thread(Messageable, Hashable):
id: :class:`int`
The thread ID.
parent_id: :class:`int`
The parent :class:`TextChannel` ID this thread belongs to.
The ID of the parent :class:`TextChannel` this thread belongs to.
owner_id: :class:`int`
The user's ID that created this thread.
The ID of the user that created this thread.
last_message_id: Optional[:class:`int`]
The last message ID of the message sent to this thread. It may
*not* point to an existing or valid message.
@ -104,9 +104,6 @@ class Thread(Messageable, Hashable):
An approximate number of messages in this thread. This caps at 50.
member_count: :class:`int`
An approximate number of members in this thread. This caps at 50.
me: Optional[:class:`ThreadMember`]
A thread member representing yourself, if you've joined the thread.
This could not be available.
archived: :class:`bool`
Whether the thread is archived.
locked: :class:`bool`
@ -115,7 +112,7 @@ class Thread(Messageable, Hashable):
Whether non-moderators can add other non-moderators to this thread.
This is always ``True`` for public threads.
archiver_id: Optional[:class:`int`]
The user's ID that archived this thread.
The ID of the user that archived this thread.
auto_archive_duration: :class:`int`
The duration in minutes until the thread is automatically archived due to inactivity.
Usually a value of 60, 1440, 4320 and 10080.
@ -136,13 +133,13 @@ class Thread(Messageable, Hashable):
'message_count',
'member_count',
'slowmode_delay',
'me',
'locked',
'archived',
'invitable',
'archiver_id',
'auto_archive_duration',
'archive_timestamp',
'member_ids',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
@ -173,12 +170,13 @@ class Thread(Messageable, Hashable):
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count']
self.member_count = data['member_count']
self.member_ids = data['member_ids_preview']
self._unroll_metadata(data['thread_metadata'])
try:
member = data['member']
except KeyError:
self.me = None
pass
else:
self.me = ThreadMember(self, member)
@ -191,17 +189,26 @@ class Thread(Messageable, Hashable):
self.invitable = data.get('invitable', True)
def _update(self, data):
try:
self.name = data['name']
except KeyError:
pass
old = copy.copy(self)
self.slowmode_delay = data.get('rate_limit_per_user', 0)
try:
self._unroll_metadata(data['thread_metadata'])
except KeyError:
pass
if (meta := data.get('thread_metadata')) is not None:
self._unroll_metadata(meta)
if (name := data.get('name')) is not None:
self.name = name
if (last_message_id := _get_as_snowflake(data, 'last_message_id')) is not None:
self.last_message_id = last_message_id
if (message_count := data.get('message_count')) is not None:
self.message_count = message_count
if (member_count := data.get('member_count')) is not None:
self.member_count = member_count
if (member_ids := data.get('member_ids_preview')) is not None:
self.member_ids = member_ids
attrs = [x for x in self.__slots__ if not any(y in x for y in ('member', 'guild', 'state', 'count'))]
if any(getattr(self, attr) != getattr(old, attr) for attr in attrs):
return old
@property
def type(self) -> ChannelType:
@ -218,6 +225,11 @@ class Thread(Messageable, Hashable):
"""Optional[:class:`Member`]: The member this thread belongs to."""
return self.guild.get_member(self.owner_id)
@property
def archiver(self) -> Optional[Member]:
"""Optional[:class:`Member`]: The member that archived this thread."""
return self.guild.get_member(self.archiver_id)
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the thread."""
@ -227,9 +239,8 @@ class Thread(Messageable, Hashable):
def members(self) -> List[ThreadMember]:
"""List[:class:`ThreadMember`]: A list of thread members in this thread.
This requires :attr:`Intents.members` to be properly filled. Most of the time however,
this data is not provided by the gateway and a call to :meth:`fetch_members` is
needed.
Initial members are not provided by Discord. You must call :func:`fetch_members`
or have thread subscribing enabled.
"""
return list(self._members.values())
@ -294,6 +305,19 @@ class Thread(Messageable, Hashable):
raise ClientException('Parent channel not found')
return parent.category_id
@property
def me(self) -> Optional[ThreadMember]:
"""Optional[:class:`ThreadMember`]: A thread member representing yourself, if you've joined the thread.
This might not be available.
"""
self_id = self._state.user.id
return self._members.get(self_id)
@me.setter
def me(self, member) -> None:
self._members[member.id] = member
def is_private(self) -> bool:
""":class:`bool`: Whether the thread is a private thread.
@ -357,17 +381,13 @@ class Thread(Messageable, Hashable):
Deletes a list of messages. This is similar to :meth:`Message.delete`
except it bulk deletes multiple messages.
As a special case, if the number of messages is 0, then nothing
is done. If the number of messages is 1 then single message
delete is done. If it's more than two, then bulk delete is used.
You cannot bulk delete more than 100 messages or messages that
are older than 14 days old.
You must have the :attr:`~Permissions.manage_messages` permission to
use this.
use this (unless they're your own).
Usable only by bot accounts.
.. note::
Users do not have access to the message bulk-delete endpoint.
Since messages are just iterated over and deleted one-by-one,
it's easy to get ratelimited using this method.
Parameters
-----------
@ -376,13 +396,8 @@ class Thread(Messageable, Hashable):
Raises
------
ClientException
The number of messages to delete was more than 100.
Forbidden
You do not have proper permissions to delete the messages or
you're not using a bot account.
NotFound
If single delete, then the message was already deleted.
You do not have proper permissions to delete the messages.
HTTPException
Deleting the messages failed.
"""
@ -390,18 +405,9 @@ class Thread(Messageable, Hashable):
messages = list(messages)
if len(messages) == 0:
return # do nothing
if len(messages) == 1:
message_id = messages[0].id
await self._state.http.delete_message(self.id, message_id)
return
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
return # Do nothing
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
await self._state._delete_messages(self.id, messages)
async def purge(
self,
@ -412,7 +418,6 @@ class Thread(Messageable, Hashable):
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = False,
bulk: bool = True,
) -> List[Message]:
"""|coro|
@ -420,10 +425,8 @@ class Thread(Messageable, Hashable):
``check``. If a ``check`` is not provided then all messages are deleted
without discrimination.
You must have the :attr:`~Permissions.manage_messages` permission to
delete messages even if they are your own (unless you are a user
account). The :attr:`~Permissions.read_message_history` permission is
also needed to retrieve message history.
The :attr:`~Permissions.read_message_history` permission is needed to
retrieve message history.
Examples
---------
@ -433,8 +436,8 @@ class Thread(Messageable, Hashable):
def is_me(m):
return m.author == client.user
deleted = await thread.purge(limit=100, check=is_me)
await thread.send(f'Deleted {len(deleted)} message(s)')
deleted = await channel.purge(limit=100, check=is_me)
await channel.send(f'Deleted {len(deleted)} message(s)')
Parameters
-----------
@ -452,10 +455,6 @@ class Thread(Messageable, Hashable):
Same as ``around`` in :meth:`history`.
oldest_first: Optional[:class:`bool`]
Same as ``oldest_first`` in :meth:`history`.
bulk: :class:`bool`
If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting
a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will
fall back to single delete if messages are older than two weeks.
Raises
-------
@ -469,54 +468,30 @@ class Thread(Messageable, Hashable):
List[:class:`.Message`]
The list of messages that were deleted.
"""
if check is MISSING:
check = lambda m: True
state = self._state
channel_id = self.id
iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around)
ret: List[Message] = []
count = 0
minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
async def _single_delete_strategy(messages: Iterable[Message]):
for m in messages:
await m.delete()
strategy = self.delete_messages if bulk else _single_delete_strategy
async for message in iterator:
if count == 100:
to_delete = ret[-100:]
await strategy(to_delete)
if count == 50:
to_delete = ret[-50:]
await state._delete_messages(channel_id, to_delete)
count = 0
await asyncio.sleep(1)
if not check(message):
continue
if message.id < minimum_time:
# older than 14 days old
if count == 1:
await ret[-1].delete()
elif count >= 2:
to_delete = ret[-count:]
await strategy(to_delete)
count = 0
strategy = _single_delete_strategy
count += 1
ret.append(message)
# SOme messages remaining to poll
if count >= 2:
# more than 2 messages -> bulk delete
to_delete = ret[-count:]
await strategy(to_delete)
elif count == 1:
# delete a single message
await ret[-1].delete()
# Some messages remaining to poll
to_delete = ret[-count:]
await state._delete_messages(channel_id, to_delete)
return ret
@ -666,23 +641,32 @@ class Thread(Messageable, Hashable):
async def fetch_members(self) -> List[ThreadMember]:
"""|coro|
Retrieves all :class:`ThreadMember` that are in this thread.
This requires :attr:`Intents.members` to get information about members
other than yourself.
Retrieves all :class:`ThreadMember` that are in this thread,
along with their respective :class:`Member`.
Raises
-------
HTTPException
Retrieving the members failed.
InvalidData
Discord didn't respond with the members.
Returns
--------
List[:class:`ThreadMember`]
All thread members in the thread.
"""
members = await self._state.http.get_thread_members(self.id)
return [ThreadMember(parent=self, data=data) for data in members]
state = self._state
await state.ws.request_lazy_guild(self.parent.guild.id, thread_member_lists=[self.id]) # type: ignore
future = state.ws.wait_for('THREAD_MEMBER_LIST_UPDATE', lambda d: int(d['thread_id']) == self.id)
try:
data = await asyncio.wait_for(future, timeout=30)
except asyncio.TimeoutError as exc:
raise InvalidData('Failed to retrieve members') from exc
members = [ThreadMember(self, {'member': member}) for member in data['members']]
for m in members:
self._add_member(m)
return self.members # Includes correct self.me
async def delete(self):
"""|coro|
@ -718,13 +702,13 @@ class Thread(Messageable, Hashable):
:class:`PartialMessage`
The partial message.
"""
from .message import PartialMessage
return PartialMessage(channel=self, id=message_id)
def _add_member(self, member: ThreadMember) -> None:
self._members[member.id] = member
if member.id != self._state.self_id:
self._members[member.id] = member
def _pop_member(self, member_id: int) -> Optional[ThreadMember]:
return self._members.pop(member_id, None)
@ -759,8 +743,12 @@ class ThreadMember(Hashable):
The thread member's ID.
thread_id: :class:`int`
The thread's ID.
joined_at: :class:`datetime.datetime`
joined_at: Optional[:class:`datetime.datetime`]
The time the member joined the thread in UTC.
Only reliably available for yourself or members joined while the user is connected to the gateway.
flags: :class:`int`
The thread member's flags. Will be its own class in the future.
Only reliably available for yourself or members joined while the user is connected to the gateway.
"""
__slots__ = (
@ -781,19 +769,32 @@ class ThreadMember(Hashable):
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
def _from_data(self, data: ThreadMemberPayload):
state = self._state
try:
self.id = int(data['user_id'])
except KeyError:
assert self._state.self_id is not None
self.id = self._state.self_id
assert state.self_id is not None
self.id = state.self_id
try:
self.thread_id = int(data['id'])
except KeyError:
self.thread_id = self.parent.id
self.joined_at = parse_time(data['join_timestamp'])
self.flags = data['flags']
self.joined_at = parse_time(data.get('join_timestamp'))
self.flags = data.get('flags')
if (data := data.get('member')) is not None:
guild = self.parent.parent.guild # type: ignore
mdata = data['member']
mdata['guild_id'] = guild.id
self.id = user_id = int(data['user_id'])
mdata['presence'] = data.get('presence')
if guild.get_member(user_id) is not None:
state.parse_guild_member_update(mdata)
else:
state.parse_guild_member_add(mdata)
@property
def thread(self) -> Thread:
@ -802,7 +803,7 @@ class ThreadMember(Hashable):
@property
def member(self) -> Optional[Member]:
"""Optional[:class:`Member`]: The member this member represents. If the member
"""Optional[:class:`Member`]: The member this :class:`ThreadMember` represents. If the member
is not cached then this will be ``None``.
"""
return self.parent.guild.get_member(self.id)
return self.parent.parent.guild.get_member(self.id) # type: ignore

Loading…
Cancel
Save