From afbbc07e980cabd7ce52e3d4c67f16758e705a3e Mon Sep 17 00:00:00 2001 From: DA344 <108473820+DA-344@users.noreply.github.com> Date: Sun, 19 Jan 2025 11:09:05 +0100 Subject: [PATCH] Add support for poll result messages --- discord/enums.py | 1 + discord/message.py | 16 +++++++ discord/poll.py | 96 ++++++++++++++++++++++++++++++++++++++-- discord/state.py | 21 +++++++++ discord/types/embed.py | 2 +- discord/types/message.py | 1 + docs/api.rst | 4 ++ 7 files changed, 136 insertions(+), 5 deletions(-) diff --git a/discord/enums.py b/discord/enums.py index 4fe5f3ffa..ce772cc87 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -266,6 +266,7 @@ class MessageType(Enum): guild_incident_report_raid = 38 guild_incident_report_false_alarm = 39 purchase_notification = 44 + poll_result = 46 class SpeakingState(Enum): diff --git a/discord/message.py b/discord/message.py index 3d755e314..3016d2f29 100644 --- a/discord/message.py +++ b/discord/message.py @@ -2268,6 +2268,13 @@ class Message(PartialMessage, Hashable): # the channel will be the correct type here ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore + if self.type is MessageType.poll_result: + if isinstance(self.reference.resolved, self.__class__): + self._state._update_poll_results(self, self.reference.resolved) + else: + if self.reference.message_id: + self._state._update_poll_results(self, self.reference.message_id) + self.application: Optional[MessageApplication] = None try: application = data['application'] @@ -2634,6 +2641,7 @@ class Message(PartialMessage, Hashable): MessageType.chat_input_command, MessageType.context_menu_command, MessageType.thread_starter_message, + MessageType.poll_result, ) @utils.cached_slot_property('_cs_system_content') @@ -2810,6 +2818,14 @@ class Message(PartialMessage, Hashable): if guild_product_purchase is not None: return f'{self.author.name} has purchased {guild_product_purchase.product_name}!' + if self.type is MessageType.poll_result: + embed = self.embeds[0] # Will always have 1 embed + poll_title = utils.get( + embed.fields, + name='poll_question_text', + ) + return f'{self.author.display_name}\'s poll {poll_title.value} has closed.' # type: ignore + # Fallback for unknown message types return '' diff --git a/discord/poll.py b/discord/poll.py index 720f91245..767f8ffae 100644 --- a/discord/poll.py +++ b/discord/poll.py @@ -29,7 +29,7 @@ from typing import Optional, List, TYPE_CHECKING, Union, AsyncIterator, Dict import datetime -from .enums import PollLayoutType, try_enum +from .enums import PollLayoutType, try_enum, MessageType from . import utils from .emoji import PartialEmoji, Emoji from .user import User @@ -125,7 +125,16 @@ class PollAnswer: Whether the current user has voted to this answer or not. """ - __slots__ = ('media', 'id', '_state', '_message', '_vote_count', 'self_voted', '_poll') + __slots__ = ( + 'media', + 'id', + '_state', + '_message', + '_vote_count', + 'self_voted', + '_poll', + '_victor', + ) def __init__( self, @@ -141,6 +150,7 @@ class PollAnswer: self._vote_count: int = 0 self.self_voted: bool = False self._poll: Poll = poll + self._victor: bool = False def _handle_vote_event(self, added: bool, self_voted: bool) -> None: if added: @@ -210,6 +220,19 @@ class PollAnswer: 'poll_media': self.media.to_dict(), } + @property + def victor(self) -> bool: + """:class:`bool`: Whether the answer is the one that had the most + votes when the poll ended. + + .. versionadded:: 2.5 + + .. note:: + + If the poll has not ended, this will always return ``False``. + """ + return self._victor + async def voters( self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None ) -> AsyncIterator[Union[User, Member]]: @@ -325,6 +348,8 @@ class Poll: '_expiry', '_finalized', '_state', + '_total_votes', + '_victor_answer_id', ) def __init__( @@ -348,6 +373,8 @@ class Poll: self._state: Optional[ConnectionState] = None self._finalized: bool = False self._expiry: Optional[datetime.datetime] = None + self._total_votes: Optional[int] = None + self._victor_answer_id: Optional[int] = None def _update(self, message: Message) -> None: self._state = message._state @@ -360,6 +387,33 @@ class Poll: self._expiry = message.poll.expires_at self._finalized = message.poll._finalized self._answers = message.poll._answers + self._update_results_from_message(message) + + def _update_results_from_message(self, message: Message) -> None: + if message.type != MessageType.poll_result or not message.embeds: + return + + result_embed = message.embeds[0] # Will always have 1 embed + fields: Dict[str, str] = {field.name: field.value for field in result_embed.fields} # type: ignore + + total_votes = fields.get('total_votes') + + if total_votes is not None: + self._total_votes = int(total_votes) + + victor_answer = fields.get('victor_answer_id') + + if victor_answer is None: + return # Can't do anything else without the victor answer + + self._victor_answer_id = int(victor_answer) + + victor_answer_votes = fields['victor_answer_votes'] + + answer = self._answers[self._victor_answer_id] + answer._victor = True + answer._vote_count = int(victor_answer_votes) + self._answers[answer.id] = answer # Ensure update def _update_results(self, data: PollResultPayload) -> None: self._finalized = data['is_finalized'] @@ -432,6 +486,32 @@ class Poll: """List[:class:`PollAnswer`]: Returns a read-only copy of the answers.""" return list(self._answers.values()) + @property + def victor_answer_id(self) -> Optional[int]: + """Optional[:class:`int`]: The victor answer ID. + + .. versionadded:: 2.5 + + .. note:: + + This will **always** be ``None`` for polls that have not yet finished. + """ + return self._victor_answer_id + + @property + def victor_answer(self) -> Optional[PollAnswer]: + """Optional[:class:`PollAnswer`]: The victor answer. + + .. versionadded:: 2.5 + + .. note:: + + This will **always** be ``None`` for polls that have not yet finished. + """ + if self.victor_answer_id is None: + return None + return self.get_answer(self.victor_answer_id) + @property def expires_at(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: A datetime object representing the poll expiry. @@ -457,12 +537,20 @@ class Poll: @property def message(self) -> Optional[Message]: - """:class:`Message`: The message this poll is from.""" + """Optional[:class:`Message`]: The message this poll is from.""" return self._message @property def total_votes(self) -> int: - """:class:`int`: Returns the sum of all the answer votes.""" + """:class:`int`: Returns the sum of all the answer votes. + + If the poll has not yet finished, this is an approximate vote count. + + .. versionchanged:: 2.5 + This now returns an exact vote count when updated from its poll results message. + """ + if self._total_votes is not None: + return self._total_votes return sum([answer.vote_count for answer in self.answers]) def is_finalised(self) -> bool: diff --git a/discord/state.py b/discord/state.py index 8dad83a88..453fbc5b6 100644 --- a/discord/state.py +++ b/discord/state.py @@ -552,6 +552,27 @@ class ConnectionState(Generic[ClientT]): poll._handle_vote(answer_id, added, self_voted) return poll + def _update_poll_results(self, from_: Message, to: Union[Message, int]) -> None: + if isinstance(to, Message): + cached = self._get_message(to.id) + elif isinstance(to, int): + cached = self._get_message(to) + + if cached is None: + return + + to = cached + else: + return + + if to.poll is None: + return + + to.poll._update_results_from_message(from_) + + if cached is not None and cached.poll: + cached.poll._update_results_from_message(from_) + async def chunker( self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None ) -> None: diff --git a/discord/types/embed.py b/discord/types/embed.py index f2f1c5a9f..376df3a1a 100644 --- a/discord/types/embed.py +++ b/discord/types/embed.py @@ -71,7 +71,7 @@ class EmbedAuthor(TypedDict, total=False): proxy_icon_url: str -EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link'] +EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link', 'poll_result'] class Embed(TypedDict, total=False): diff --git a/discord/types/message.py b/discord/types/message.py index 1ec86681b..ae38db46f 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -174,6 +174,7 @@ MessageType = Literal[ 38, 39, 44, + 46, ] diff --git a/docs/api.rst b/docs/api.rst index 0b4015f78..b9348ec4b 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1887,6 +1887,10 @@ of :class:`enum.Enum`. .. versionadded:: 2.5 + .. attribute:: poll_result + + The system message sent when a poll has closed. + .. class:: UserFlags Represents Discord User flags.