diff --git a/discord/abc.py b/discord/abc.py index 6efb2cd7c..7bafeec6c 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -86,7 +86,8 @@ if TYPE_CHECKING: OverwriteType, ) - MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel] + PartialMessageableChannel = Union[TextChannel, Thread, DMChannel] + MessageableChannel = Union[PartialMessageableChannel, GroupChannel] SnowflakeTime = Union["Snowflake", datetime] MISSING = utils.MISSING diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 339e10730..dd7d6577e 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -48,6 +48,7 @@ from .errors import * if TYPE_CHECKING: from .context import Context + from discord.message import PartialMessageableChannel __all__ = ( @@ -349,11 +350,11 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): return guild_id, message_id, channel_id @staticmethod - def _resolve_channel(ctx, guild_id, channel_id): + def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: if guild_id is not None: guild = ctx.bot.get_guild(guild_id) if guild is not None and channel_id is not None: - return guild.get_channel(channel_id) + return guild._resolve_channel(channel_id) # type: ignore else: return None else: @@ -470,6 +471,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): return result + class TextChannelConverter(IDConverter[discord.TextChannel]): """Converts to a :class:`~discord.TextChannel`. @@ -567,6 +569,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) + class ThreadConverter(IDConverter[discord.Thread]): """Coverts to a :class:`~discord.Thread`. @@ -584,6 +587,7 @@ class ThreadConverter(IDConverter[discord.Thread]): async def convert(self, ctx: Context, argument: str) -> discord.Thread: return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) + class ColourConverter(Converter[discord.Colour]): """Converts to a :class:`~discord.Colour`. @@ -844,7 +848,7 @@ class clean_content(Converter[str]): fix_channel_mentions: bool = False, use_nicknames: bool = True, escape_markdown: bool = False, - remove_markdown: bool = False + remove_markdown: bool = False, ) -> None: self.fix_channel_mentions = fix_channel_mentions self.use_nicknames = use_nicknames @@ -855,6 +859,7 @@ class clean_content(Converter[str]): msg = ctx.message if ctx.guild: + def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' @@ -862,7 +867,9 @@ class clean_content(Converter[str]): def resolve_role(id: int) -> str: r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) return f'@{r.name}' if r else '@deleted-role' + else: + def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) return f'@{m.name}' if m else '@deleted-user' @@ -871,10 +878,13 @@ class clean_content(Converter[str]): return '@deleted-role' if self.fix_channel_mentions and ctx.guild: + def resolve_channel(id: int) -> str: c = ctx.guild.get_channel(id) return f'#{c.name}' if c else '#deleted-channel' + else: + def resolve_channel(id: int) -> str: return f'<#{id}>' diff --git a/discord/message.py b/discord/message.py index b4604b383..ed2b2df94 100644 --- a/discord/message.py +++ b/discord/message.py @@ -67,7 +67,7 @@ if TYPE_CHECKING: from .types.user import User as UserPayload from .types.embed import Embed as EmbedPayload from .abc import Snowflake - from .abc import GuildChannel + from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel from .components import Component from .state import ConnectionState from .channel import TextChannel, GroupChannel, DMChannel @@ -657,7 +657,7 @@ class Message(Hashable): self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] self.application: Optional[MessageApplicationPayload] = data.get('application') self.activity: Optional[MessageActivityPayload] = data.get('activity') - self.channel: Union[TextChannel, Thread, DMChannel, GroupChannel] = channel + self.channel: MessageableChannel = channel self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) self.type: MessageType = try_enum(MessageType, data['type']) self.pinned: bool = data['pinned'] @@ -1557,8 +1557,11 @@ class PartialMessage(Hashable): a message and channel ID are present. There are two ways to construct this class. The first one is through - the constructor itself, and the second is via - :meth:`TextChannel.get_partial_message` or :meth:`DMChannel.get_partial_message`. + the constructor itself, and the second is via the following: + + - :meth:`TextChannel.get_partial_message` + - :meth:`Thread.get_partial_message` + - :meth:`DMChannel.get_partial_message` Note that this class is trimmed down and has no rich attributes. @@ -1580,7 +1583,7 @@ class PartialMessage(Hashable): Attributes ----------- - channel: Union[:class:`TextChannel`, :class:`DMChannel`] + channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`] The channel associated with this partial message. id: :class:`int` The message ID. @@ -1601,11 +1604,11 @@ class PartialMessage(Hashable): to_reference = Message.to_reference to_message_reference_dict = Message.to_message_reference_dict - def __init__(self, *, channel: Union[TextChannel, DMChannel], id: int): + def __init__(self, *, channel: PartialMessageableChannel, id: int): if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private): raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}') - self.channel: Union[TextChannel, DMChannel] = channel + self.channel: PartialMessageableChannel = channel self._state: ConnectionState = channel._state self.id: int = id diff --git a/discord/threads.py b/discord/threads.py index 24eda6512..daf7e5180 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations + from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING import time import asyncio @@ -48,7 +49,7 @@ if TYPE_CHECKING: from .guild import Guild from .channel import TextChannel from .member import Member - from .message import Message + from .message import Message, PartialMessage from .abc import Snowflake, SnowflakeTime from .role import Role from .permissions import Permissions @@ -191,6 +192,7 @@ class Thread(Messageable, Hashable): self._unroll_metadata(data['thread_metadata']) except KeyError: pass + @property def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" @@ -626,6 +628,29 @@ class Thread(Messageable, Hashable): """ await self._state.http.delete_channel(self.id) + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 2.0 + + Parameters + ------------ + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + --------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + def _add_member(self, member: ThreadMember) -> None: self._members[member.id] = member