From f8f0575c19b4bca82d54bebc17b0084979b8c5e0 Mon Sep 17 00:00:00 2001 From: DA344 <108473820+DA-344@users.noreply.github.com> Date: Fri, 10 May 2024 12:14:12 +0200 Subject: [PATCH] Add support for Polls Co-authored-by: owocado <24418520+owocado@users.noreply.github.com> Co-authored-by: Josh <8677174+bijij@users.noreply.github.com> Co-authored-by: Trevor Flahardy <75498301+trevorflahardy@users.noreply.github.com> --- discord/__init__.py | 1 + discord/abc.py | 14 + discord/enums.py | 6 + discord/ext/commands/context.py | 5 + discord/http.py | 50 ++- discord/message.py | 59 +++- discord/permissions.py | 18 + discord/poll.py | 571 ++++++++++++++++++++++++++++++++ discord/raw_models.py | 32 ++ discord/state.py | 60 ++++ discord/types/gateway.py | 8 + discord/types/message.py | 2 + discord/types/poll.py | 88 +++++ discord/webhook/async_.py | 18 +- discord/webhook/sync.py | 22 +- docs/api.rst | 66 ++++ 16 files changed, 999 insertions(+), 21 deletions(-) create mode 100644 discord/poll.py create mode 100644 discord/types/poll.py diff --git a/discord/__init__.py b/discord/__init__.py index cdc6f6f94..375bf6c2f 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -60,6 +60,7 @@ from .partial_emoji import * from .payments import * from .permissions import * from .player import * +from .poll import * from .profile import * from .promotions import * from .raw_models import * diff --git a/discord/abc.py b/discord/abc.py index 13affae84..3a6c50b8d 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -96,6 +96,7 @@ if TYPE_CHECKING: StageChannel, CategoryChannel, ) + from .poll import Poll from .threads import Thread from .types.channel import ( PermissionOverwrite as PermissionOverwritePayload, @@ -1665,6 +1666,7 @@ class Messageable: mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1683,6 +1685,7 @@ class Messageable: mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1701,6 +1704,7 @@ class Messageable: mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1719,6 +1723,7 @@ class Messageable: mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1737,6 +1742,7 @@ class Messageable: mention_author: Optional[bool] = None, suppress_embeds: bool = False, silent: bool = False, + poll: Optional[Poll] = None, ) -> Message: """|coro| @@ -1806,6 +1812,10 @@ class Messageable: in the UI, but will not actually send a notification. .. versionadded:: 2.0 + poll: :class:`~discord.Poll` + The poll to send with this message. + + .. versionadded:: 2.4 Raises -------- @@ -1869,11 +1879,15 @@ class Messageable: stickers=sticker_ids, flags=flags, network_type=NetworkConnectionType.unknown, + poll=poll, ) as params: data = await state.http.send_message(channel.id, params=params) ret = state.create_message(channel=channel, data=data) + if poll: + poll._update(ret) + if delete_after is not None: await ret.delete(delay=delete_after) return ret diff --git a/discord/enums.py b/discord/enums.py index 880e6a10e..26547679d 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -127,6 +127,7 @@ __all__ = ( 'HubType', 'NetworkConnectionType', 'NetworkConnectionSpeed', + 'PollLayoutType', ) if TYPE_CHECKING: @@ -1679,6 +1680,11 @@ class NetworkConnectionSpeed(Enum): return self.value +class PollLayoutType(Enum): + default = 1 + image_only_answers = 2 + + def create_unknown_value(cls: Type[E], val: Any) -> E: value_cls = cls._enum_value_cls_ # type: ignore # This is narrowed below name = f'unknown_{val}' diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 23386a0fe..2c97fc363 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -56,6 +56,7 @@ if TYPE_CHECKING: from discord.member import Member from discord.mentions import AllowedMentions from discord.message import MessageReference, PartialMessage + from discord.poll import Poll from discord.state import ConnectionState from discord.sticker import GuildSticker, StickerItem from discord.user import ClientUser, User @@ -452,6 +453,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -471,6 +473,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -490,6 +493,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -509,6 +513,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... diff --git a/discord/http.py b/discord/http.py index 34d79cdc6..0abd8c295 100644 --- a/discord/http.py +++ b/discord/http.py @@ -81,6 +81,7 @@ if TYPE_CHECKING: from .flags import MessageFlags from .enums import ChannelType, InteractionType from .embeds import Embed + from .poll import Poll from .types import ( application, @@ -118,6 +119,7 @@ if TYPE_CHECKING: subscriptions, sticker, welcome_screen, + poll, ) from .types.snowflake import Snowflake, SnowflakeList @@ -242,6 +244,7 @@ def handle_message_parameters( network_type: NetworkConnectionType = MISSING, channel_payload: Dict[str, Any] = MISSING, applied_tags: Optional[SnowflakeList] = MISSING, + poll: Optional[Poll] = MISSING, ) -> MultipartParameters: if files is not MISSING and file is not MISSING: raise TypeError('Cannot mix file and files keyword arguments.') @@ -342,6 +345,9 @@ def handle_message_parameters( } payload.update(channel_payload) + if poll not in (MISSING, None): + payload['poll'] = poll._to_dict() # type: ignore + # Legacy uploading multipart = [] to_upload = [file for file in files if isinstance(file, File)] if files else None @@ -4362,16 +4368,50 @@ class HTTPClient: # Misc - async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: + def get_poll_answer_voters( + self, + channel_id: Snowflake, + message_id: Snowflake, + answer_id: Snowflake, + after: Optional[Snowflake] = None, + limit: Optional[int] = None, + ) -> Response[poll.PollAnswerVoters]: + params = {} + if after: + params['after'] = int(after) + if limit is not None: + params['limit'] = limit + + return self.request( + Route( + 'GET', + '/channels/{channel_id}/polls/{message_id}/answers/{answer_id}', + channel_id=channel_id, + message_id=message_id, + answer_id=answer_id, + ), + params=params, + ) + + def end_poll(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: + return self.request( + Route( + 'POST', + '/channels/{channel_id}/polls/{message_id}/expire', + channel_id=channel_id, + message_id=message_id, + ) + ) + + async def get_gateway(self, *, encoding: str = 'json', compress: Optional[str] = None) -> str: try: data = await self.request(Route('GET', '/gateway')) except HTTPException as exc: raise GatewayNotFound() from exc - if zlib: - value = '{0}?encoding={1}&v={2}&compress=zlib-stream' + if compress: + return f'{data["url"]}?encoding={encoding}&v={INTERNAL_API_VERSION}&compress={compress}' else: - value = '{0}?encoding={1}&v={2}' - return value.format(data['url'], encoding, INTERNAL_API_VERSION) + return f'{data["url"]}?encoding={encoding}&v={INTERNAL_API_VERSION}' def get_user(self, user_id: Snowflake) -> Response[user.APIUser]: return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) diff --git a/discord/message.py b/discord/message.py index f7f74573c..72e886ccb 100644 --- a/discord/message.py +++ b/discord/message.py @@ -69,7 +69,7 @@ from .interactions import Interaction from .commands import MessageCommand from .abc import _handle_commands from .application import IntegrationApplication - +from .poll import Poll if TYPE_CHECKING: from typing_extensions import Self @@ -1266,6 +1266,7 @@ class PartialMessage(Hashable): mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1283,6 +1284,7 @@ class PartialMessage(Hashable): mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1300,6 +1302,7 @@ class PartialMessage(Hashable): mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1317,6 +1320,7 @@ class PartialMessage(Hashable): mention_author: bool = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1386,6 +1390,30 @@ class PartialMessage(Hashable): """ return await self.channel.greet(sticker, reference=self, **kwargs) + async def end_poll(self) -> Message: + """|coro| + + Ends the :class:`Poll` attached to this message. + + This can only be done if you are the message author. + + If the poll was successfully ended, then it returns the updated :class:`Message`. + + Raises + ------ + ~discord.HTTPException + Ending the poll failed. + + Returns + ------- + :class:`.Message` + The updated message. + """ + + data = await self._state.http.end_poll(self.channel.id, self.id) + + return Message(state=self._state, channel=self.channel, data=data) + def to_reference(self, *, fail_if_not_exists: bool = True) -> MessageReference: """Creates a :class:`~discord.MessageReference` from the current message. @@ -1555,6 +1583,10 @@ class Message(PartialMessage, Hashable): The interaction that this message is a response to. .. versionadded:: 2.0 + poll: Optional[:class:`Poll`] + The poll attached to this message. + + .. versionadded:: 2.4 hit: :class:`bool` Whether the message was a hit in a search result. As surrounding messages are no longer returned in search results, this is always ``True`` for search results. @@ -1608,6 +1640,7 @@ class Message(PartialMessage, Hashable): 'role_subscription', 'application_id', 'position', + 'poll', 'hit', 'total_results', 'analytics_id', @@ -1651,6 +1684,16 @@ class Message(PartialMessage, Hashable): self.application_id: Optional[int] = utils._get_as_snowflake(data, 'application_id') self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] self.call: Optional[CallMessage] = None + self.interaction: Optional[Interaction] = None + + # This updates the poll so it has the counts, if the message + # was previously cached. + self.poll: Optional[Poll] = state._get_poll(self.id) + if self.poll is None: + try: + self.poll = Poll._from_data(data=data['poll'], message=self, state=state) + except KeyError: + pass try: # If the channel doesn't have a guild attribute, we handle that @@ -1662,10 +1705,9 @@ class Message(PartialMessage, Hashable): else: guild_id = channel.guild_id # type: ignore -<<<<<<< HEAD self.guild_id: Optional[int] = guild_id self.guild = state._get_guild(guild_id) -======= + self._thread: Optional[Thread] = None if self.guild is not None: @@ -1681,9 +1723,6 @@ class Message(PartialMessage, Hashable): else: self._thread = Thread(guild=self.guild, state=state, data=thread) - self.interaction: Optional[MessageInteraction] = None ->>>>>>> 29344b9c (Add thread getters to Message) - self.application: Optional[IntegrationApplication] = None try: application = data['application'] @@ -1692,14 +1731,6 @@ class Message(PartialMessage, Hashable): else: self.application = IntegrationApplication(state=self._state, data=application) - self.interaction: Optional[Interaction] = None - try: - interaction = data['interaction'] - except KeyError: - pass - else: - self.interaction = Interaction._from_message(self, **interaction) - try: ref = data['message_reference'] except KeyError: diff --git a/discord/permissions.py b/discord/permissions.py index ce8474dde..4fc6b042b 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -734,6 +734,22 @@ class Permissions(BaseFlags): """ return 1 << 46 + @flag_value + def send_polls(self) -> int: + """:class:`bool`: Returns ``True`` if a user can send poll messages. + + .. versionadded:: 2.4 + """ + return 1 << 49 + + @make_permission_alias('send_polls') + def create_polls(self) -> int: + """:class:`bool`: An alias for :attr:`send_polls`. + + .. versionadded:: 2.4 + """ + return 1 << 49 + def _augment_from_permissions(cls): cls.VALID_NAMES = set(Permissions.VALID_FLAGS) @@ -854,6 +870,8 @@ class PermissionOverwrite: send_voice_messages: Optional[bool] create_expressions: Optional[bool] create_events: Optional[bool] + send_polls: Optional[bool] + create_polls: Optional[bool] def __init__(self, **kwargs: Optional[bool]): self._values: Dict[str, Optional[bool]] = {} diff --git a/discord/poll.py b/discord/poll.py new file mode 100644 index 000000000..f9b2d04c5 --- /dev/null +++ b/discord/poll.py @@ -0,0 +1,571 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + + +from typing import Optional, List, TYPE_CHECKING, Union, AsyncIterator, Dict + +import datetime + +from .enums import PollLayoutType, try_enum +from . import utils +from .emoji import PartialEmoji, Emoji +from .user import User +from .object import Object +from .errors import ClientException + +if TYPE_CHECKING: + from typing_extensions import Self + + from .message import Message + from .abc import Snowflake + from .state import ConnectionState + from .member import Member + + from .types.poll import ( + PollCreate as PollCreatePayload, + PollMedia as PollMediaPayload, + PollAnswerCount as PollAnswerCountPayload, + Poll as PollPayload, + PollAnswerWithID as PollAnswerWithIDPayload, + PollResult as PollResultPayload, + PollAnswer as PollAnswerPayload, + ) + + +__all__ = ( + 'Poll', + 'PollAnswer', + 'PollMedia', +) + +MISSING = utils.MISSING +PollMediaEmoji = Union[PartialEmoji, Emoji, str] + + +class PollMedia: + """Represents the poll media for a poll item. + + .. versionadded:: 2.4 + + Attributes + ---------- + text: :class:`str` + The displayed text. + emoji: Optional[Union[:class:`PartialEmoji`, :class:`Emoji`]] + The attached emoji for this media. This is only valid for poll answers. + """ + + __slots__ = ('text', 'emoji') + + def __init__(self, /, text: str, emoji: Optional[PollMediaEmoji] = None) -> None: + self.text: str = text + self.emoji: Optional[Union[PartialEmoji, Emoji]] = PartialEmoji.from_str(emoji) if isinstance(emoji, str) else emoji + + def __repr__(self) -> str: + return f'' + + def to_dict(self) -> PollMediaPayload: + payload: PollMediaPayload = {'text': self.text} + + if self.emoji is not None: + payload['emoji'] = self.emoji._to_partial().to_dict() + + return payload + + @classmethod + def from_dict(cls, *, data: PollMediaPayload) -> Self: + emoji = data.get('emoji') + + if emoji: + return cls(text=data['text'], emoji=PartialEmoji.from_dict(emoji)) + return cls(text=data['text']) + + +class PollAnswer: + """Represents a poll's answer. + + .. container:: operations + + .. describe:: str(x) + + Returns this answer's text, if any. + + .. versionadded:: 2.4 + + Attributes + ---------- + id: :class:`int` + The ID of this answer. + media: :class:`PollMedia` + The display data for this answer. + self_voted: :class:`bool` + Whether the current user has voted to this answer or not. + """ + + __slots__ = ('media', 'id', '_state', '_message', '_vote_count', 'self_voted', '_poll') + + def __init__( + self, + *, + message: Optional[Message], + poll: Poll, + data: PollAnswerWithIDPayload, + ) -> None: + self.media: PollMedia = PollMedia.from_dict(data=data['poll_media']) + self.id: int = int(data['answer_id']) + self._message: Optional[Message] = message + self._state: Optional[ConnectionState] = message._state if message else None + self._vote_count: int = 0 + self.self_voted: bool = False + self._poll: Poll = poll + + def _handle_vote_event(self, added: bool, self_voted: bool) -> None: + if added: + self._vote_count += 1 + else: + self._vote_count -= 1 + self.self_voted = self_voted + + def _update_with_results(self, payload: PollAnswerCountPayload) -> None: + self._vote_count = int(payload['count']) + self.self_voted = payload['me_voted'] + + def __str__(self) -> str: + return self.media.text + + def __repr__(self) -> str: + return f'' + + @classmethod + def from_params( + cls, + id: int, + text: str, + emoji: Optional[PollMediaEmoji] = None, + *, + poll: Poll, + message: Optional[Message], + ) -> Self: + poll_media: PollMediaPayload = {'text': text} + if emoji is not None: + emoji = PartialEmoji.from_str(emoji) if isinstance(emoji, str) else emoji._to_partial() + emoji_data = emoji.to_dict() + # No need to remove animated key as it will be ignored + poll_media['emoji'] = emoji_data + + payload: PollAnswerWithIDPayload = {'answer_id': id, 'poll_media': poll_media} + + return cls(data=payload, message=message, poll=poll) + + @property + def text(self) -> str: + """:class:`str`: Returns this answer's displayed text.""" + return self.media.text + + @property + def emoji(self) -> Optional[Union[PartialEmoji, Emoji]]: + """Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]]: Returns this answer's displayed + emoji, if any. + """ + return self.media.emoji + + @property + def vote_count(self) -> int: + """:class:`int`: Returns an approximate count of votes for this answer. + + If the poll is finished, the count is exact. + """ + return self._vote_count + + @property + def poll(self) -> Poll: + """:class:`Poll`: Returns the parent poll of this answer""" + return self._poll + + def _to_dict(self) -> PollAnswerPayload: + return { + 'poll_media': self.media.to_dict(), + } + + async def voters( + self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None + ) -> AsyncIterator[Union[User, Member]]: + """Returns an :term:`asynchronous iterator` representing the users that have voted on this answer. + + The ``after`` parameter must represent a user + and meet the :class:`abc.Snowflake` abc. + + This can only be called when the parent poll was sent to a message. + + Examples + -------- + + Usage :: + + async for voter in poll_answer.voters(): + print(f'{voter} has voted for {poll_answer}!') + + Flattening into a list: :: + + voters = [voter async for voter in poll_answer.voters()] + # voters is now a list of User + + Parameters + ---------- + limit: Optional[:class:`int`] + The maximum number of results to return. + If not provided, returns all the users who + voted on this poll answer. + after: Optional[:class:`abc.Snowflake`] + For pagination, voters are sorted by member. + + Raises + ------ + HTTPException + Retrieving the users failed. + + Yields + ------ + Union[:class:`User`, :class:`Member`] + The member (if retrievable) or the user that has voted + on this poll answer. The case where it can be a :class:`Member` + is in a guild message context. Sometimes it can be a :class:`User` + if the member has left the guild or if the member is not cached. + """ + + if not self._message or not self._state: # Make type checker happy + raise ClientException('You cannot fetch users to a poll not sent with a message') + + if limit is None: + if not self._message.poll: + limit = 100 + else: + limit = self.vote_count or 100 + + while limit > 0: + retrieve = min(limit, 100) + + message = self._message + guild = self._message.guild + state = self._state + after_id = after.id if after else None + + data = await state.http.get_poll_answer_voters( + message.channel.id, message.id, self.id, after=after_id, limit=retrieve + ) + users = data['users'] + + if len(users) == 0: + # No more voters to fetch, terminate loop + break + + limit -= len(users) + after = Object(id=int(users[-1]['id'])) + + if not guild or isinstance(guild, Object): + for raw_user in reversed(users): + yield User(state=self._state, data=raw_user) + continue + + for raw_member in reversed(users): + member_id = int(raw_member['id']) + member = guild.get_member(member_id) + + yield member or User(state=self._state, data=raw_member) + + +class Poll: + """Represents a message's Poll. + + .. versionadded:: 2.4 + + Parameters + ---------- + question: Union[:class:`PollMedia`, :class:`str`] + The poll's displayed question. The text can be up to 300 characters. + duration: :class:`datetime.timedelta` + The duration of the poll. Duration must be in hours. + multiple: :class:`bool` + Whether users are allowed to select more than one answer. + Defaultsto ``False``. + layout_type: :class:`PollLayoutType` + The layout type of the poll. Defaults to :attr:`PollLayoutType.default`. + """ + + __slots__ = ( + 'multiple', + '_answers', + 'duration', + 'layout_type', + '_question_media', + '_message', + '_expiry', + '_finalized', + '_state', + ) + + def __init__( + self, + question: Union[PollMedia, str], + duration: datetime.timedelta, + *, + multiple: bool = False, + layout_type: PollLayoutType = PollLayoutType.default, + ) -> None: + self._question_media: PollMedia = PollMedia(text=question, emoji=None) if isinstance(question, str) else question + self._answers: Dict[int, PollAnswer] = {} + self.duration: datetime.timedelta = duration + + self.multiple: bool = multiple + self.layout_type: PollLayoutType = layout_type + + # NOTE: These attributes are set manually when calling + # _from_data, so it should be ``None`` now. + self._message: Optional[Message] = None + self._state: Optional[ConnectionState] = None + self._finalized: bool = False + self._expiry: Optional[datetime.datetime] = None + + def _update(self, message: Message) -> None: + self._state = message._state + self._message = message + + if not message.poll: + return + + # The message's poll contains the more up to date data. + self._expiry = message.poll.expires_at + self._finalized = message.poll._finalized + + def _update_results(self, data: PollResultPayload) -> None: + self._finalized = data['is_finalized'] + + for count in data['answer_counts']: + answer = self.get_answer(int(count['id'])) + if not answer: + continue + + answer._update_with_results(count) + + def _handle_vote(self, answer_id: int, added: bool, self_voted: bool = False): + answer = self.get_answer(answer_id) + if not answer: + return + + answer._handle_vote_event(added, self_voted) + + @classmethod + def _from_data(cls, *, data: PollPayload, message: Message, state: ConnectionState) -> Self: + multiselect = data.get('allow_multiselect', False) + layout_type = try_enum(PollLayoutType, data.get('layout_type', 1)) + question_data = data.get('question') + question = question_data.get('text') + expiry = utils.parse_time(data['expiry']) # If obtained via API, then expiry is set. + duration = expiry - message.created_at + # self.created_at = message.created_at + # duration = self.created_at - expiry + + if (duration.total_seconds() / 3600) > 168: # As the duration may exceed little milliseconds then we fix it + duration = datetime.timedelta(days=7) + + self = cls( + duration=duration, + multiple=multiselect, + layout_type=layout_type, + question=question, + ) + self._answers = { + int(answer['answer_id']): PollAnswer(data=answer, message=message, poll=self) for answer in data['answers'] + } + self._message = message + self._state = state + self._expiry = expiry + + try: + self._update_results(data['results']) + except KeyError: + pass + + return self + + def _to_dict(self) -> PollCreatePayload: + data: PollCreatePayload = { + 'allow_multiselect': self.multiple, + 'question': self._question_media.to_dict(), + 'duration': self.duration.total_seconds() / 3600, + 'layout_type': self.layout_type.value, + 'answers': [answer._to_dict() for answer in self.answers], + } + return data + + def __repr__(self) -> str: + return f"" + + @property + def question(self) -> str: + """:class:`str`: Returns this poll answer question string.""" + return self._question_media.text + + @property + def answers(self) -> List[PollAnswer]: + """List[:class:`PollAnswer`]: Returns a read-only copy of the answers""" + return list(self._answers.values()) + + @property + def expires_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: A datetime object representing the poll expiry. + + .. note:: + + This will **always** be ``None`` for stateless polls. + """ + return self._expiry + + @property + def created_at(self) -> Optional[datetime.datetime]: + """:class:`datetime.datetime`: Returns the poll's creation time, or ``None`` if user-created.""" + + if not self._message: + return + return self._message.created_at + + @property + def message(self) -> Optional[Message]: + """:class:`Message`: The message this poll is from.""" + return self._message + + @property + def total_votes(self) -> int: + """:class:`int`: Returns the sum of all the answer votes.""" + return sum([answer.vote_count for answer in self.answers]) + + def is_finalised(self) -> bool: + """:class:`bool`: Returns whether the poll has finalised. + + This always returns ``False`` for stateless polls. + """ + return self._finalized + + is_finalized = is_finalised + + def copy(self) -> Self: + """Returns a stateless copy of this poll. + + This is meant to be used when you want to edit a stateful poll. + + Returns + ------- + :class:`Poll` + The copy of the poll. + """ + + new = self.__class__(question=self.question, duration=self.duration) + + # We want to return a stateless copy of the poll, so we should not + # override new._answers as our answers may contain a state + for answer in self.answers: + new.add_answer(text=answer.text, emoji=answer.emoji) + + return new + + def add_answer( + self, + *, + text: str, + emoji: Optional[Union[PartialEmoji, Emoji, str]] = None, + ) -> Self: + """Appends a new answer to this poll. + + Parameters + ---------- + text: :class:`str` + The text label for this poll answer. Can be up to 55 + characters. + emoji: Union[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`] + The emoji to display along the text. + + Raises + ------ + ClientException + Cannot append answers to a poll that is active. + + Returns + ------- + :class:`Poll` + This poll with the new answer appended. This allows fluent-style chaining. + """ + + if self._message: + raise ClientException('Cannot append answers to a poll that is active') + + answer = PollAnswer.from_params(id=len(self.answers) + 1, text=text, emoji=emoji, message=self._message, poll=self) + self._answers[answer.id] = answer + return self + + def get_answer( + self, + /, + id: int, + ) -> Optional[PollAnswer]: + """Returns the answer with the provided ID or ``None`` if not found. + + Parameters + ---------- + id: :class:`int` + The ID of the answer to get. + + Returns + ------- + Optional[:class:`PollAnswer`] + The answer. + """ + + return self._answers.get(id) + + async def end(self) -> Self: + """|coro| + + Ends the poll. + + Raises + ------ + ClientException + This poll has no attached message. + HTTPException + Ending the poll failed. + + Returns + ------- + :class:`Poll` + The updated poll. + """ + + if not self._message or not self._state: # Make type checker happy + raise ClientException('This poll has no attached message.') + + self._message = await self._message.end_poll() + + return self diff --git a/discord/raw_models.py b/discord/raw_models.py index 66ce02b05..893e3a615 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -49,6 +49,7 @@ if TYPE_CHECKING: MessageReactionRemoveEvent, MessageUpdateEvent, NonChannelAckEvent, + PollVoteActionEvent, ThreadDeleteEvent, ThreadMembersUpdate, ) @@ -72,6 +73,7 @@ __all__ = ( 'RawMessageAckEvent', 'RawUserFeatureAckEvent', 'RawGuildFeatureAckEvent', + 'RawPollVoteActionEvent', ) @@ -500,3 +502,33 @@ class RawGuildFeatureAckEvent(RawUserFeatureAckEvent): def guild(self) -> Guild: """:class:`Guild`: The guild that the feature was acknowledged in.""" return self._state._get_or_create_unavailable_guild(self.guild_id) + + +class RawPollVoteActionEvent(_RawReprMixin): + """Represents the payload for a :func:`on_raw_poll_vote_add` or :func:`on_raw_poll_vote_remove` + event. + + .. versionadded:: 2.4 + + Attributes + ---------- + user_id: :class:`int` + The ID of the user that added or removed a vote. + channel_id: :class:`int` + The channel ID where the poll vote action took place. + message_id: :class:`int` + The message ID that contains the poll the user added or removed their vote on. + guild_id: Optional[:class:`int`] + The guild ID where the vote got added or removed, if applicable.. + answer_id: :class:`int` + The poll answer's ID the user voted on. + """ + + __slots__ = ('user_id', 'channel_id', 'message_id', 'guild_id', 'answer_id') + + def __init__(self, data: PollVoteActionEvent) -> None: + self.user_id: int = int(data['user_id']) + self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data['message_id']) + self.guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') + self.answer_id: int = int(data['answer_id']) diff --git a/discord/state.py b/discord/state.py index a29b81013..959e90c07 100644 --- a/discord/state.py +++ b/discord/state.py @@ -115,6 +115,7 @@ if TYPE_CHECKING: from .client import Client from .gateway import DiscordWebSocket from .calls import Call + from .poll import Poll from .types.automod import AutoModerationRule, AutoModerationActionExecution from .types.snowflake import Snowflake @@ -1317,6 +1318,12 @@ class ConnectionState: else utils.find(lambda m: m.id == msg_id, reversed(self._call_message_cache.values())) ) + def _get_poll(self, msg_id: Optional[int]) -> Optional[Poll]: + message = self._get_message(msg_id) + if not message: + return + return message.poll + def _add_guild_from_data(self, data: GuildPayload) -> Guild: guild = self.create_guild(data) self._add_guild(guild) @@ -1360,6 +1367,13 @@ class ConnectionState: except NotFound: pass + def _update_poll_counts(self, message: Message, answer_id: int, added: bool, self_voted: bool = False) -> Optional[Poll]: + poll = message.poll + if not poll: + return + poll._handle_vote(answer_id, added, self_voted) + return poll + def subscribe_guild( self, guild: Guild, typing: bool = True, activities: bool = True, threads: bool = True, member_updates: bool = True ) -> Coroutine: @@ -3476,6 +3490,52 @@ class ConnectionState: parse_nothing = lambda *_: None # parse_guild_application_commands_update = parse_nothing # Grabbed directly in command iterators + def parse_message_poll_vote_add(self, data: gw.PollVoteActionEvent) -> None: + raw = RawPollVoteActionEvent(data) + + self.dispatch('raw_poll_vote_add', raw) + + message = self._get_message(raw.message_id) + guild = self._get_guild(raw.guild_id) + + if guild: + user = guild.get_member(raw.user_id) + else: + user = self.get_user(raw.user_id) + + if message and user: + poll = self._update_poll_counts(message, raw.answer_id, True, raw.user_id == self.self_id) + if not poll: + _log.warning( + 'POLL_VOTE_ADD referencing message with ID: %s does not have a poll. Discarding.', raw.message_id + ) + return + + self.dispatch('poll_vote_add', user, poll.get_answer(raw.answer_id)) + + def parse_message_poll_vote_remove(self, data: gw.PollVoteActionEvent) -> None: + raw = RawPollVoteActionEvent(data) + + self.dispatch('raw_poll_vote_remove', raw) + + message = self._get_message(raw.message_id) + guild = self._get_guild(raw.guild_id) + + if guild: + user = guild.get_member(raw.user_id) + else: + user = self.get_user(raw.user_id) + + if message and user: + poll = self._update_poll_counts(message, raw.answer_id, False, raw.user_id == self.self_id) + if not poll: + _log.warning( + 'POLL_VOTE_REMOVE referencing message with ID: %s does not have a poll. Discarding.', raw.message_id + ) + return + + self.dispatch('poll_vote_remove', user, poll.get_answer(raw.answer_id)) + def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, (TextChannel, Thread, VoiceChannel, StageChannel)): return channel.guild.get_member(user_id) diff --git a/discord/types/gateway.py b/discord/types/gateway.py index 64c796fd8..ee9902e7f 100644 --- a/discord/types/gateway.py +++ b/discord/types/gateway.py @@ -701,3 +701,11 @@ class GuildMemberListUpdateEvent(TypedDict): online_count: int groups: List[GuildMemberListGroup] ops: List[GuildMemberListOP] + + +class PollVoteActionEvent(TypedDict): + user_id: Snowflake + channel_id: Snowflake + message_id: Snowflake + guild_id: NotRequired[Snowflake] + answer_id: int diff --git a/discord/types/message.py b/discord/types/message.py index fa1a558a4..008a24dfe 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -38,6 +38,7 @@ from .interactions import MessageInteraction from .application import BaseApplication from .sticker import StickerItem from .threads import Thread, ThreadMember +from .poll import Poll class PartialMessage(TypedDict): @@ -161,6 +162,7 @@ class Message(PartialMessage): attachments: List[Attachment] embeds: List[Embed] pinned: bool + poll: NotRequired[Poll] type: MessageType member: NotRequired[Member] mention_channels: NotRequired[List[ChannelMention]] diff --git a/discord/types/poll.py b/discord/types/poll.py new file mode 100644 index 000000000..fabdbd48f --- /dev/null +++ b/discord/types/poll.py @@ -0,0 +1,88 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + + +from typing import List, TypedDict, Optional, Literal, TYPE_CHECKING +from typing_extensions import NotRequired + +from .snowflake import Snowflake + +if TYPE_CHECKING: + from .user import User + from .emoji import PartialEmoji + + +LayoutType = Literal[1] # 1 = Default + + +class PollMedia(TypedDict): + text: str + emoji: NotRequired[Optional[PartialEmoji]] + + +class PollAnswer(TypedDict): + poll_media: PollMedia + + +class PollAnswerWithID(PollAnswer): + answer_id: int + + +class PollAnswerCount(TypedDict): + id: Snowflake + count: int + me_voted: bool + + +class PollAnswerVoters(TypedDict): + users: List[User] + + +class PollResult(TypedDict): + is_finalized: bool + answer_counts: List[PollAnswerCount] + + +class PollCreate(TypedDict): + allow_multiselect: bool + answers: List[PollAnswer] + duration: float + layout_type: LayoutType + question: PollMedia + + +# We don't subclass Poll as it will +# still have the duration field, which +# is converted into expiry when poll is +# fetched from a message or returned +# by a `send` method in a Messageable +class Poll(TypedDict): + allow_multiselect: bool + answers: List[PollAnswerWithID] + expiry: str + layout_type: LayoutType + question: PollMedia + results: PollResult diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index e6cff254d..ee4c08444 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -71,7 +71,7 @@ if TYPE_CHECKING: from ..emoji import Emoji from ..channel import VoiceChannel from ..abc import Snowflake - import datetime + from ..poll import Poll from ..types.webhook import ( Webhook as WebhookPayload, SourceGuild as SourceGuildPayload, @@ -1373,6 +1373,7 @@ class Webhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> WebhookMessage: ... @@ -1395,6 +1396,7 @@ class Webhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> None: ... @@ -1416,6 +1418,7 @@ class Webhook(BaseWebhook): suppress_embeds: bool = False, silent: bool = False, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> Optional[WebhookMessage]: """|coro| @@ -1492,6 +1495,15 @@ class Webhook(BaseWebhook): .. versionadded:: 2.1 + poll: :class:`Poll` + The poll to send with this message. + + .. warning:: + + When sending a Poll via webhook, you cannot manually end it. + + .. versionadded:: 2.4 + Raises -------- HTTPException @@ -1551,6 +1563,7 @@ class Webhook(BaseWebhook): allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, applied_tags=applied_tag_ids, + poll=poll, ) as params: adapter = async_context.get() thread_id: Optional[int] = None @@ -1574,6 +1587,9 @@ class Webhook(BaseWebhook): if wait: msg = self._create_message(data, thread=thread) + if poll is not MISSING and msg: + poll._update(msg) + return msg async def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> WebhookMessage: diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index a5eebfbca..13cb357cc 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: from ..file import File from ..embeds import Embed + from ..poll import Poll from ..mentions import AllowedMentions from ..message import Attachment from ..abc import Snowflake @@ -871,6 +872,7 @@ class SyncWebhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> SyncWebhookMessage: ... @@ -893,6 +895,7 @@ class SyncWebhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> None: ... @@ -914,6 +917,7 @@ class SyncWebhook(BaseWebhook): suppress_embeds: bool = False, silent: bool = False, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> Optional[SyncWebhookMessage]: """Sends a message using the webhook. @@ -978,6 +982,14 @@ class SyncWebhook(BaseWebhook): in the UI, but will not actually send a notification. .. versionadded:: 2.0 + poll: :class:`Poll` + The poll to send with this message. + + .. warning:: + + When sending a Poll via webhook, you cannot manually end it. + + .. versionadded:: 2.1 Raises -------- @@ -1036,6 +1048,7 @@ class SyncWebhook(BaseWebhook): previous_allowed_mentions=previous_mentions, flags=flags, applied_tags=applied_tag_ids, + poll=poll, ) as params: adapter: WebhookAdapter = _get_webhook_adapter() thread_id: Optional[int] = None @@ -1053,8 +1066,15 @@ class SyncWebhook(BaseWebhook): wait=wait, ) + msg = None + if wait: - return self._create_message(data, thread=thread) + msg = self._create_message(data, thread=thread) + + if poll is not MISSING and msg: + poll._update(msg) + + return msg def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> SyncWebhookMessage: """Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook. diff --git a/docs/api.rst b/docs/api.rst index 300eb2dca..1459c1059 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1263,6 +1263,38 @@ Messages :param message_id: The ID of the message that was deleted. :type message_id: :class:`int` +Polls +~~~~~~ + +.. function:: on_poll_vote_add(user, answer) + on_poll_vote_remove(user, answer) + + Called when a :class:`Poll` gains or loses a vote. If the ``user`` or ``message`` + are not cached then this event will not be called. + + .. note:: + + If the poll allows multiple answers and the user removes or adds multiple votes, this + event will be called as many times as votes that are added or removed. + + .. versionadded:: 2.1 + + :param user: The user that performed the action. + :type user: Union[:class:`User`, :class:`Member`] + :param answer: The answer the user voted or removed their vote from. + :type answer: :class:`PollAnswer` + +.. function:: on_raw_poll_vote_add(payload) + on_raw_poll_vote_remove(payload) + + Called when a :class:`Poll` gains or loses a vote. Unlike :func:`on_poll_vote_add` and :func:`on_poll_vote_remove` + this is called regardless of the state of the internal user and message cache. + + .. versionadded:: 2.1 + + :param payload: The raw event payload data. + :type payload: :class:`RawPollVoteActionEvent` + Reactions ~~~~~~~~~~ @@ -6011,6 +6043,16 @@ of :class:`enum.Enum`. An alias for :attr:`college`. +.. class:: PollLayoutType + + Represents how a poll answers are shown + + .. versionadded:: 2.4 + + .. attribute:: default + + The default layout. + .. _discord-api-audit-logs: Audit Log Data @@ -7973,6 +8015,11 @@ RawEvent .. autoclass:: RawMessageUpdateEvent() :members: +.. attributetable:: RawPollVoteActionEvent + +.. autoclass:: RawPollVoteActionEvent() + :members: + .. attributetable:: RawReactionActionEvent .. autoclass:: RawReactionActionEvent() @@ -8323,6 +8370,25 @@ Flags .. autoclass:: SystemChannelFlags() :members: +Poll +~~~~ + +.. attributetable:: Poll + +.. autoclass:: Poll() + :members: + +.. attributetable:: PollAnswer + +.. autoclass:: PollAnswer() + :members: + :inherited-members: + +.. attributetable:: PollMedia + +.. autoclass:: PollMedia() + :members: + Exceptions ------------