diff --git a/discord/message.py b/discord/message.py index 288475a41..345dc5dca 100644 --- a/discord/message.py +++ b/discord/message.py @@ -22,10 +22,14 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio import datetime import re import io +from os import PathLike +from typing import TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar from . import utils from .reaction import Reaction @@ -42,6 +46,26 @@ from .guild import Guild from .mixins import Hashable from .sticker import Sticker +if TYPE_CHECKING: + from .types.message import ( + Message as MessagePayload, + Attachment as AttachmentPayload, + MessageReference as MessageReferencePayload, + MessageApplication as MessageApplicationPayload, + MessageActivity as MessageActivityPayload, + Reaction as ReactionPayload, + ) + + from .types.member import Member as MemberPayload + from .types.user import User as UserPayload + from .types.embed import Embed as EmbedPayload + from .abc import Snowflake + from .abc import GuildChannel, PrivateChannel, Messageable + from .state import ConnectionState + from .channel import TextChannel, GroupChannel, DMChannel + + EmojiInputType = Union[Emoji, PartialEmoji, str] + __all__ = ( 'Attachment', 'Message', @@ -116,7 +140,7 @@ class Attachment(Hashable): __slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type') - def __init__(self, *, data, state): + def __init__(self, *, data: AttachmentPayload, state: ConnectionState): self.id = int(data['id']) self.size = data['size'] self.height = data.get('height') @@ -127,17 +151,17 @@ class Attachment(Hashable): self._http = state.http self.content_type = data.get('content_type') - def is_spoiler(self): + def is_spoiler(self) -> bool: """:class:`bool`: Whether this attachment contains a spoiler.""" return self.filename.startswith('SPOILER_') - def __repr__(self): + def __repr__(self) -> str: return f'' - def __str__(self): + def __str__(self) -> str: return self.url or '' - async def save(self, fp, *, seek_begin=True, use_cached=False): + async def save(self, fp: Union[io.BufferedIOBase, PathLike], *, seek_begin: bool = True, use_cached: bool = False) -> int: """|coro| Saves this attachment into a file-like object. @@ -181,7 +205,7 @@ class Attachment(Hashable): with open(fp, 'wb') as f: return f.write(data) - async def read(self, *, use_cached=False): + async def read(self, *, use_cached: bool = False) -> bytes: """|coro| Retrieves the content of this attachment as a :class:`bytes` object. @@ -216,7 +240,7 @@ class Attachment(Hashable): data = await self._http.get_from_cdn(url) return data - async def to_file(self, *, use_cached=False, spoiler=False): + async def to_file(self, *, use_cached: bool = False, spoiler: bool = False) -> File: """|coro| Converts the attachment into a :class:`File` suitable for sending via @@ -258,8 +282,8 @@ class Attachment(Hashable): data = await self.read(use_cached=use_cached) return File(io.BytesIO(data), filename=self.filename, spoiler=spoiler) - def to_dict(self): - result = { + def to_dict(self) -> AttachmentPayload: + result: AttachmentPayload = { 'filename': self.filename, 'id': self.id, 'proxy_url': self.proxy_url, @@ -287,24 +311,24 @@ class DeletedReferencedMessage: __slots__ = ('_parent') - def __init__(self, parent): + def __init__(self, parent: MessageReference): self._parent = parent - def __repr__(self): + def __repr__(self) -> str: return f"" @property - def id(self): + def id(self) -> int: """:class:`int`: The message ID of the deleted referenced message.""" return self._parent.message_id @property - def channel_id(self): + def channel_id(self) -> int: """:class:`int`: The channel ID of the deleted referenced message.""" return self._parent.channel_id @property - def guild_id(self): + def guild_id(self) -> Optional[int]: """Optional[:class:`int`]: The guild ID of the deleted referenced message.""" return self._parent.guild_id @@ -345,16 +369,16 @@ class MessageReference: __slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state') - def __init__(self, *, message_id, channel_id, guild_id=None, fail_if_not_exists=True): - self._state = None - self.resolved = None + def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True): + self._state: Optional[ConnectionState] = None + self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None self.message_id = message_id self.channel_id = channel_id self.guild_id = guild_id self.fail_if_not_exists = fail_if_not_exists @classmethod - def with_state(cls, state, data): + def with_state(cls, state: ConnectionState, data: MessageReferencePayload) -> MessageReference: self = cls.__new__(cls) self.message_id = utils._get_as_snowflake(data, 'message_id') self.channel_id = int(data.pop('channel_id')) @@ -365,7 +389,7 @@ class MessageReference: return self @classmethod - def from_message(cls, message, *, fail_if_not_exists=True): + def from_message(cls, message: Message, *, fail_if_not_exists: bool = True): """Creates a :class:`MessageReference` from an existing :class:`~discord.Message`. .. versionadded:: 1.6 @@ -390,12 +414,12 @@ class MessageReference: return self @property - def cached_message(self): + def cached_message(self) -> Optional[Message]: """Optional[:class:`~discord.Message`]: The cached message, if found in the internal message cache.""" - return self._state._get_message(self.message_id) + return self._state and self._state._get_message(self.message_id) @property - def jump_url(self): + def jump_url(self) -> str: """:class:`str`: Returns a URL that allows the client to jump to the referenced message. .. versionadded:: 1.7 @@ -403,11 +427,11 @@ class MessageReference: guild_id = self.guild_id if self.guild_id is not None else '@me' return f'https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}' - def __repr__(self): + def __repr__(self) -> str: return f'' - def to_dict(self): - result = {'message_id': self.message_id} if self.message_id is not None else {} + def to_dict(self) -> MessageReferencePayload: + result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} result['channel_id'] = self.channel_id if self.guild_id is not None: result['guild_id'] = self.guild_id @@ -460,17 +484,17 @@ class Message(Hashable): type: :class:`MessageType` The type of message. In most cases this should not be checked, but it is helpful in cases where it might be a system message for :attr:`system_content`. - author: :class:`abc.User` + author: Union[:class:`Member`, :class:`abc.User`] A :class:`Member` that sent the message. If :attr:`channel` is a private channel or the user has the left the guild, then it is a :class:`User` instead. content: :class:`str` The actual contents of the message. - nonce + nonce: Union[:class:`str`, :class:`int`] The value used by the discord guild and the client to verify that the message is successfully sent. This is not stored long term within Discord's servers and is only used ephemerally. embeds: List[:class:`Embed`] A list of embeds the message has. - channel: Union[:class:`abc.Messageable`] + channel: Union[:class:`TextChannel`, :class:`DMChannel`, :class:`GroupChannel`] The :class:`TextChannel` that the message was sent from. Could be a :class:`DMChannel` or :class:`GroupChannel` if it's a private message. reference: Optional[:class:`~discord.MessageReference`] @@ -552,7 +576,10 @@ class Message(Hashable): '_cs_system_content', '_cs_guild', '_state', 'reactions', 'reference', 'application', 'activity', 'stickers') - def __init__(self, *, state, channel, data): + _HANDLERS: ClassVar[List[Tuple[str, Callable[..., None]]]] + _CACHED_SLOTS: ClassVar[List[str]] + + def __init__(self, *, state: ConnectionState, channel: Union[TextChannel, DMChannel, GroupChannel], data: MessagePayload): self._state = state self.id = int(data['id']) self.webhook_id = utils._get_as_snowflake(data, 'webhook_id') @@ -600,10 +627,10 @@ class Message(Hashable): except KeyError: continue - def __repr__(self): + def __repr__(self) -> str: return f'' - def _try_patch(self, data, key, transform=None): + def _try_patch(self, data, key, transform=None) -> None: try: value = data[key] except KeyError: @@ -614,7 +641,7 @@ class Message(Hashable): else: setattr(self, key, transform(value)) - def _add_reaction(self, data, emoji, user_id): + def _add_reaction(self, data, emoji, user_id) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) is_me = data['me'] = user_id == self._state.self_id @@ -628,7 +655,7 @@ class Message(Hashable): return reaction - def _remove_reaction(self, data, emoji, user_id): + def _remove_reaction(self, data: ReactionPayload, emoji: EmojiInputType, user_id: int) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) if reaction is None: @@ -647,7 +674,7 @@ class Message(Hashable): return reaction - def _clear_emoji(self, emoji): + def _clear_emoji(self, emoji) -> Optional[Reaction]: to_check = str(emoji) for index, reaction in enumerate(self.reactions): if str(reaction.emoji) == to_check: @@ -679,50 +706,50 @@ class Message(Hashable): except AttributeError: pass - def _handle_edited_timestamp(self, value): + def _handle_edited_timestamp(self, value: str) -> None: self._edited_timestamp = utils.parse_time(value) - def _handle_pinned(self, value): + def _handle_pinned(self, value: int) -> None: self.pinned = value - def _handle_flags(self, value): + def _handle_flags(self, value: int) -> None: self.flags = MessageFlags._from_value(value) - def _handle_application(self, value): + def _handle_application(self, value: MessageApplicationPayload) -> None: self.application = value - def _handle_activity(self, value): + def _handle_activity(self, value: MessageActivityPayload) -> None: self.activity = value - def _handle_mention_everyone(self, value): + def _handle_mention_everyone(self, value: bool) -> None: self.mention_everyone = value - def _handle_tts(self, value): + def _handle_tts(self, value: bool) -> None: self.tts = value - def _handle_type(self, value): + def _handle_type(self, value: int) -> None: self.type = try_enum(MessageType, value) - def _handle_content(self, value): + def _handle_content(self, value: str) -> None: self.content = value - def _handle_attachments(self, value): + def _handle_attachments(self, value: List[AttachmentPayload]) -> None: self.attachments = [Attachment(data=a, state=self._state) for a in value] - def _handle_embeds(self, value): + def _handle_embeds(self, value: List[EmbedPayload]) -> None: self.embeds = [Embed.from_dict(data) for data in value] - def _handle_nonce(self, value): + def _handle_nonce(self, value: Union[str, int]) -> None: self.nonce = value - def _handle_author(self, author): + def _handle_author(self, author: UserPayload) -> None: self.author = self._state.store_user(author) if isinstance(self.guild, Guild): found = self.guild.get_member(self.author.id) if found is not None: self.author = found - def _handle_member(self, member): + def _handle_member(self, member: MemberPayload) -> None: # The gateway now gives us full Member objects sometimes with the following keys # deaf, mute, joined_at, roles # For the sake of performance I'm going to assume that the only @@ -732,13 +759,13 @@ class Message(Hashable): author = self.author try: # Update member reference - author._update_from_message(member) + author._update_from_message(member) # type: ignore except AttributeError: # It's a user here # TODO: consider adding to cache here self.author = Member._from_message(message=self, data=member) - def _handle_mentions(self, mentions): + def _handle_mentions(self, mentions: List[UserPayload]) -> None: self.mentions = r = [] guild = self.guild state = self._state @@ -754,7 +781,7 @@ class Message(Hashable): else: r.append(Member._try_upgrade(data=mention, guild=guild, state=state)) - def _handle_mention_roles(self, role_mentions): + def _handle_mention_roles(self, role_mentions: List[int]) -> None: self.role_mentions = [] if isinstance(self.guild, Guild): for role_id in map(int, role_mentions): @@ -762,21 +789,21 @@ class Message(Hashable): if role is not None: self.role_mentions.append(role) - def _rebind_channel_reference(self, new_channel): + def _rebind_channel_reference(self, new_channel: Union[TextChannel, DMChannel, GroupChannel]) -> None: self.channel = new_channel try: - del self._cs_guild + del self._cs_guild # type: ignore except AttributeError: pass @utils.cached_slot_property('_cs_guild') - def guild(self): + def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild that the message belongs to, if applicable.""" return getattr(self.channel, 'guild', None) @utils.cached_slot_property('_cs_raw_mentions') - def raw_mentions(self): + def raw_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of user IDs matched with the syntax of ``<@user_id>`` in the message content. @@ -786,28 +813,28 @@ class Message(Hashable): return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)] @utils.cached_slot_property('_cs_raw_channel_mentions') - def raw_channel_mentions(self): + def raw_channel_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of channel IDs matched with the syntax of ``<#channel_id>`` in the message content. """ return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)] @utils.cached_slot_property('_cs_raw_role_mentions') - def raw_role_mentions(self): + def raw_role_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of role IDs matched with the syntax of ``<@&role_id>`` in the message content. """ return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)] @utils.cached_slot_property('_cs_channel_mentions') - def channel_mentions(self): + def channel_mentions(self) -> List[GuildChannel]: if self.guild is None: return [] it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions)) return utils._unique(it) @utils.cached_slot_property('_cs_clean_content') - def clean_content(self): + def clean_content(self) -> str: """:class:`str`: A property that returns the content in a "cleaned up" manner. This basically means that mentions are transformed into the way the client shows it. e.g. ``<#id>`` will transform @@ -857,22 +884,22 @@ class Message(Hashable): return escape_mentions(result) @property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: The message's creation time in UTC.""" return utils.snowflake_time(self.id) @property - def edited_at(self): + def edited_at(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: An aware UTC datetime object containing the edited time of the message.""" return self._edited_timestamp @property - def jump_url(self): + def jump_url(self) -> str: """:class:`str`: Returns a URL that allows the client to jump to this message.""" guild_id = getattr(self.guild, 'id', '@me') return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}' - def is_system(self): + def is_system(self) -> bool: """:class:`bool`: Whether the message is a system message. .. versionadded:: 1.3 @@ -943,7 +970,7 @@ class Message(Hashable): return f'{self.author.name} has added {self.content} to this channel' if self.type is MessageType.guild_stream: - return f'{self.author.name} is live! Now streaming {self.author.activity.name}' + return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore if self.type is MessageType.guild_discovery_disqualified: return 'This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details.' @@ -960,7 +987,7 @@ class Message(Hashable): if self.type is MessageType.guild_invite_reminder: return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!' - async def delete(self, *, delay=None): + async def delete(self, *, delay: Optional[float] = None) -> None: """|coro| Deletes the message. @@ -988,18 +1015,18 @@ class Message(Hashable): Deleting the message failed. """ if delay is not None: - async def delete(): + async def delete(delay: float): await asyncio.sleep(delay) try: await self._state.http.delete_message(self.channel.id, self.id) except HTTPException: pass - asyncio.create_task(delete()) + asyncio.create_task(delete(delay)) else: await self._state.http.delete_message(self.channel.id, self.id) - async def edit(self, **fields): + async def edit(self, **fields: Any) -> None: """|coro| Edits the message. @@ -1102,7 +1129,7 @@ class Message(Hashable): if delete_after is not None: await self.delete(delay=delete_after) - async def publish(self): + async def publish(self) -> None: """|coro| Publishes this message to your announcement channel. @@ -1120,7 +1147,7 @@ class Message(Hashable): await self._state.http.publish_message(self.channel.id, self.id) - async def pin(self, *, reason=None): + async def pin(self, *, reason: Optional[str] = None) -> None: """|coro| Pins the message. @@ -1149,7 +1176,7 @@ class Message(Hashable): await self._state.http.pin_message(self.channel.id, self.id, reason=reason) self.pinned = True - async def unpin(self, *, reason=None): + async def unpin(self, *, reason: Optional[str] = None) -> None: """|coro| Unpins the message. @@ -1177,7 +1204,7 @@ class Message(Hashable): await self._state.http.unpin_message(self.channel.id, self.id, reason=reason) self.pinned = False - async def add_reaction(self, emoji): + async def add_reaction(self, emoji: EmojiInputType) -> None: """|coro| Add a reaction to the message. @@ -1208,7 +1235,7 @@ class Message(Hashable): emoji = convert_emoji_reaction(emoji) await self._state.http.add_reaction(self.channel.id, self.id, emoji) - async def remove_reaction(self, emoji, member): + async def remove_reaction(self, emoji: Union[EmojiInputType, Reaction], member: Snowflake) -> None: """|coro| Remove a reaction by the member from the message. @@ -1247,7 +1274,7 @@ class Message(Hashable): else: await self._state.http.remove_reaction(self.channel.id, self.id, emoji, member.id) - async def clear_reaction(self, emoji): + async def clear_reaction(self, emoji: Union[EmojiInputType, Reaction]) -> None: """|coro| Clears a specific reaction from the message. @@ -1278,7 +1305,7 @@ class Message(Hashable): emoji = convert_emoji_reaction(emoji) await self._state.http.clear_single_reaction(self.channel.id, self.id, emoji) - async def clear_reactions(self): + async def clear_reactions(self) -> None: """|coro| Removes all the reactions from the message. @@ -1294,7 +1321,7 @@ class Message(Hashable): """ await self._state.http.clear_reactions(self.channel.id, self.id) - async def reply(self, content=None, **kwargs): + async def reply(self, content: Optional[str] = None, **kwargs) -> Message: """|coro| A shortcut method to :meth:`.abc.Messageable.send` to reply to the @@ -1320,7 +1347,7 @@ class Message(Hashable): return await self.channel.send(content, reference=self, **kwargs) - def to_reference(self, *, fail_if_not_exists=True): + def to_reference(self, *, fail_if_not_exists: bool = True) -> MessageReference: """Creates a :class:`~discord.MessageReference` from the current message. .. versionadded:: 1.6 @@ -1341,8 +1368,8 @@ class Message(Hashable): return MessageReference.from_message(self, fail_if_not_exists=fail_if_not_exists) - def to_message_reference_dict(self): - data = { + def to_message_reference_dict(self) -> MessageReferencePayload: + data: MessageReferencePayload = { 'message_id': self.id, 'channel_id': self.channel.id, } @@ -1411,7 +1438,7 @@ class PartialMessage(Hashable): 'to_message_reference_dict', ) - def __init__(self, *, channel, id): + def __init__(self, *, channel: Union[GuildChannel, PrivateChannel], id: int): if channel.type not in (ChannelType.text, ChannelType.news, ChannelType.private): raise TypeError(f'Expected TextChannel or DMChannel not {type(channel)!r}') @@ -1419,29 +1446,29 @@ class PartialMessage(Hashable): self._state = channel._state self.id = id - def _update(self, data): + def _update(self, data) -> None: # This is used for duck typing purposes. # Just do nothing with the data. pass # Also needed for duck typing purposes # n.b. not exposed - pinned = property(None, lambda x, y: ...) + pinned = property(None, lambda x, y: None) - def __repr__(self): + def __repr__(self) -> str: return f'' @property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: The partial message's creation time in UTC.""" return utils.snowflake_time(self.id) @utils.cached_slot_property('_cs_guild') - def guild(self): + def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable.""" return getattr(self.channel, 'guild', None) - async def fetch(self): + async def fetch(self) -> Message: """|coro| Fetches the partial message to a full :class:`Message`. @@ -1464,7 +1491,7 @@ class PartialMessage(Hashable): data = await self._state.http.get_message(self.channel.id, self.id) return self._state.create_message(channel=self.channel, data=data) - async def edit(self, **fields): + async def edit(self, **fields: Any) -> Optional[Message]: """|coro| Edits the message. @@ -1558,7 +1585,7 @@ class PartialMessage(Hashable): data = await self._state.http.edit_message(self.channel.id, self.id, **fields) if delete_after is not None: - await self.delete(delay=delete_after) + await self.delete(delay=delete_after) # type: ignore if fields: - return self._state.create_message(channel=self.channel, data=data) + return self._state.create_message(channel=self.channel, data=data) # type: ignore diff --git a/discord/types/message.py b/discord/types/message.py index 86ac82df6..1aa8259b4 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -51,6 +51,7 @@ class _AttachmentOptional(TypedDict, total=False): height: Optional[int] width: Optional[int] content_type: str + spoiler: bool class Attachment(_AttachmentOptional):