Browse Source

[commands] Fix (Partial)MessageConverter to work with thread messages

pull/7197/head
Nadir Chowdhury 4 years ago
committed by GitHub
parent
commit
e2624b9a31
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      discord/abc.py
  2. 16
      discord/ext/commands/converter.py
  3. 17
      discord/message.py
  4. 27
      discord/threads.py

3
discord/abc.py

@ -86,7 +86,8 @@ if TYPE_CHECKING:
OverwriteType, OverwriteType,
) )
MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel] PartialMessageableChannel = Union[TextChannel, Thread, DMChannel]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime] SnowflakeTime = Union["Snowflake", datetime]
MISSING = utils.MISSING MISSING = utils.MISSING

16
discord/ext/commands/converter.py

@ -48,6 +48,7 @@ from .errors import *
if TYPE_CHECKING: if TYPE_CHECKING:
from .context import Context from .context import Context
from discord.message import PartialMessageableChannel
__all__ = ( __all__ = (
@ -349,11 +350,11 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return guild_id, message_id, channel_id return guild_id, message_id, channel_id
@staticmethod @staticmethod
def _resolve_channel(ctx, guild_id, channel_id): def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is not None: if guild_id is not None:
guild = ctx.bot.get_guild(guild_id) guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None: if guild is not None and channel_id is not None:
return guild.get_channel(channel_id) return guild._resolve_channel(channel_id) # type: ignore
else: else:
return None return None
else: else:
@ -470,6 +471,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result return result
class TextChannelConverter(IDConverter[discord.TextChannel]): class TextChannelConverter(IDConverter[discord.TextChannel]):
"""Converts to a :class:`~discord.TextChannel`. """Converts to a :class:`~discord.TextChannel`.
@ -567,6 +569,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
class ThreadConverter(IDConverter[discord.Thread]): class ThreadConverter(IDConverter[discord.Thread]):
"""Coverts to a :class:`~discord.Thread`. """Coverts to a :class:`~discord.Thread`.
@ -584,6 +587,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
async def convert(self, ctx: Context, argument: str) -> discord.Thread: async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
class ColourConverter(Converter[discord.Colour]): class ColourConverter(Converter[discord.Colour]):
"""Converts to a :class:`~discord.Colour`. """Converts to a :class:`~discord.Colour`.
@ -844,7 +848,7 @@ class clean_content(Converter[str]):
fix_channel_mentions: bool = False, fix_channel_mentions: bool = False,
use_nicknames: bool = True, use_nicknames: bool = True,
escape_markdown: bool = False, escape_markdown: bool = False,
remove_markdown: bool = False remove_markdown: bool = False,
) -> None: ) -> None:
self.fix_channel_mentions = fix_channel_mentions self.fix_channel_mentions = fix_channel_mentions
self.use_nicknames = use_nicknames self.use_nicknames = use_nicknames
@ -855,6 +859,7 @@ class clean_content(Converter[str]):
msg = ctx.message msg = ctx.message
if ctx.guild: if ctx.guild:
def resolve_member(id: int) -> str: def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user'
@ -862,7 +867,9 @@ class clean_content(Converter[str]):
def resolve_role(id: int) -> str: def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
return f'@{r.name}' if r else '@deleted-role' return f'@{r.name}' if r else '@deleted-role'
else: else:
def resolve_member(id: int) -> str: def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user' return f'@{m.name}' if m else '@deleted-user'
@ -871,10 +878,13 @@ class clean_content(Converter[str]):
return '@deleted-role' return '@deleted-role'
if self.fix_channel_mentions and ctx.guild: if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str: def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id) c = ctx.guild.get_channel(id)
return f'#{c.name}' if c else '#deleted-channel' return f'#{c.name}' if c else '#deleted-channel'
else: else:
def resolve_channel(id: int) -> str: def resolve_channel(id: int) -> str:
return f'<#{id}>' return f'<#{id}>'

17
discord/message.py

@ -67,7 +67,7 @@ if TYPE_CHECKING:
from .types.user import User as UserPayload from .types.user import User as UserPayload
from .types.embed import Embed as EmbedPayload from .types.embed import Embed as EmbedPayload
from .abc import Snowflake from .abc import Snowflake
from .abc import GuildChannel from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel
from .components import Component from .components import Component
from .state import ConnectionState from .state import ConnectionState
from .channel import TextChannel, GroupChannel, DMChannel from .channel import TextChannel, GroupChannel, DMChannel
@ -657,7 +657,7 @@ class Message(Hashable):
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']]
self.application: Optional[MessageApplicationPayload] = data.get('application') self.application: Optional[MessageApplicationPayload] = data.get('application')
self.activity: Optional[MessageActivityPayload] = data.get('activity') self.activity: Optional[MessageActivityPayload] = data.get('activity')
self.channel: Union[TextChannel, Thread, DMChannel, GroupChannel] = channel self.channel: MessageableChannel = channel
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp'])
self.type: MessageType = try_enum(MessageType, data['type']) self.type: MessageType = try_enum(MessageType, data['type'])
self.pinned: bool = data['pinned'] self.pinned: bool = data['pinned']
@ -1557,8 +1557,11 @@ class PartialMessage(Hashable):
a message and channel ID are present. a message and channel ID are present.
There are two ways to construct this class. The first one is through There are two ways to construct this class. The first one is through
the constructor itself, and the second is via the constructor itself, and the second is via the following:
:meth:`TextChannel.get_partial_message` or :meth:`DMChannel.get_partial_message`.
- :meth:`TextChannel.get_partial_message`
- :meth:`Thread.get_partial_message`
- :meth:`DMChannel.get_partial_message`
Note that this class is trimmed down and has no rich attributes. Note that this class is trimmed down and has no rich attributes.
@ -1580,7 +1583,7 @@ class PartialMessage(Hashable):
Attributes Attributes
----------- -----------
channel: Union[:class:`TextChannel`, :class:`DMChannel`] channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]
The channel associated with this partial message. The channel associated with this partial message.
id: :class:`int` id: :class:`int`
The message ID. The message ID.
@ -1601,11 +1604,11 @@ class PartialMessage(Hashable):
to_reference = Message.to_reference to_reference = Message.to_reference
to_message_reference_dict = Message.to_message_reference_dict to_message_reference_dict = Message.to_message_reference_dict
def __init__(self, *, channel: Union[TextChannel, DMChannel], id: int): def __init__(self, *, channel: PartialMessageableChannel, id: int):
if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private): if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private):
raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}') raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}')
self.channel: Union[TextChannel, DMChannel] = channel self.channel: PartialMessageableChannel = channel
self._state: ConnectionState = channel._state self._state: ConnectionState = channel._state
self.id: int = id self.id: int = id

27
discord/threads.py

@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
import time import time
import asyncio import asyncio
@ -48,7 +49,7 @@ if TYPE_CHECKING:
from .guild import Guild from .guild import Guild
from .channel import TextChannel from .channel import TextChannel
from .member import Member from .member import Member
from .message import Message from .message import Message, PartialMessage
from .abc import Snowflake, SnowflakeTime from .abc import Snowflake, SnowflakeTime
from .role import Role from .role import Role
from .permissions import Permissions from .permissions import Permissions
@ -191,6 +192,7 @@ class Thread(Messageable, Hashable):
self._unroll_metadata(data['thread_metadata']) self._unroll_metadata(data['thread_metadata'])
except KeyError: except KeyError:
pass pass
@property @property
def type(self) -> ChannelType: def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type.""" """:class:`ChannelType`: The channel's Discord type."""
@ -626,6 +628,29 @@ class Thread(Messageable, Hashable):
""" """
await self._state.http.delete_channel(self.id) await self._state.http.delete_channel(self.id)
def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
This is useful if you want to work with a message and only have its ID without
doing an unnecessary API call.
.. versionadded:: 2.0
Parameters
------------
message_id: :class:`int`
The message ID to create a partial message for.
Returns
---------
:class:`PartialMessage`
The partial message.
"""
from .message import PartialMessage
return PartialMessage(channel=self, id=message_id)
def _add_member(self, member: ThreadMember) -> None: def _add_member(self, member: ThreadMember) -> None:
self._members[member.id] = member self._members[member.id] = member

Loading…
Cancel
Save