diff --git a/discord/__init__.py b/discord/__init__.py index d239c8f3b..e3148e513 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -69,6 +69,7 @@ from .interactions import * from .components import * from .threads import * from .automod import * +from .poll import * class VersionInfo(NamedTuple): diff --git a/discord/abc.py b/discord/abc.py index 8eeb9d4d0..656a38659 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -92,6 +92,7 @@ if TYPE_CHECKING: VoiceChannel, StageChannel, ) + from .poll import Poll from .threads import Thread from .ui.view import View from .types.channel import ( @@ -1350,6 +1351,7 @@ class Messageable: view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1370,6 +1372,7 @@ class Messageable: view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1390,6 +1393,7 @@ class Messageable: view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1410,6 +1414,7 @@ class Messageable: view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1431,6 +1436,7 @@ class Messageable: view: Optional[View] = None, suppress_embeds: bool = False, silent: bool = False, + poll: Optional[Poll] = None, ) -> Message: """|coro| @@ -1516,6 +1522,10 @@ class Messageable: in the UI, but will not actually send a notification. .. versionadded:: 2.2 + poll: :class:`~discord.Poll` + The poll to send with this message. + + .. versionadded:: 2.4 Raises -------- @@ -1582,6 +1592,7 @@ class Messageable: stickers=sticker_ids, view=view, flags=flags, + poll=poll, ) as params: data = await state.http.send_message(channel.id, params=params) @@ -1589,6 +1600,9 @@ class Messageable: if view and not view.is_finished(): state.store_view(view, ret.id) + if poll: + poll._update(ret) + if delete_after is not None: await ret.delete(delay=delete_after) return ret diff --git a/discord/client.py b/discord/client.py index f452ca30a..a91be7160 100644 --- a/discord/client.py +++ b/discord/client.py @@ -107,6 +107,7 @@ if TYPE_CHECKING: RawThreadMembersUpdate, RawThreadUpdateEvent, RawTypingEvent, + RawPollVoteActionEvent, ) from .reaction import Reaction from .role import Role @@ -116,6 +117,7 @@ if TYPE_CHECKING: from .ui.item import Item from .voice_client import VoiceProtocol from .audit_logs import AuditLogEntry + from .poll import PollAnswer # fmt: off @@ -1815,6 +1817,30 @@ class Client: ) -> Tuple[Member, VoiceState, VoiceState]: ... + # Polls + + @overload + async def wait_for( + self, + event: Literal['poll_vote_add', 'poll_vote_remove'], + /, + *, + check: Optional[Callable[[Union[User, Member], PollAnswer], bool]] = None, + timeout: Optional[float] = None, + ) -> Tuple[Union[User, Member], PollAnswer]: + ... + + @overload + async def wait_for( + self, + event: Literal['raw_poll_vote_add', 'raw_poll_vote_remove'], + /, + *, + check: Optional[Callable[[RawPollVoteActionEvent], bool]] = None, + timeout: Optional[float] = None, + ) -> RawPollVoteActionEvent: + ... + # Commands @overload diff --git a/discord/enums.py b/discord/enums.py index f1af2d790..f7989a195 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -73,6 +73,7 @@ __all__ = ( 'SKUType', 'EntitlementType', 'EntitlementOwnerType', + 'PollLayoutType', ) @@ -818,6 +819,10 @@ class EntitlementOwnerType(Enum): user = 2 +class PollLayoutType(Enum): + default = 1 + + def create_unknown_value(cls: Type[E], val: Any) -> E: value_cls = cls._enum_value_cls_ # type: ignore # This is narrowed below name = f'unknown_{val}' diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index d4052cbbd..ad9c286ee 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -50,6 +50,7 @@ if TYPE_CHECKING: from discord.message import MessageReference, PartialMessage from discord.ui import View from discord.types.interactions import ApplicationCommandInteractionData + from discord.poll import Poll from .cog import Cog from .core import Command @@ -641,6 +642,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -662,6 +664,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -683,6 +686,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -704,6 +708,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -826,6 +831,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -847,6 +853,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -868,6 +875,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -889,6 +897,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = ..., ephemeral: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -911,6 +920,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): suppress_embeds: bool = False, ephemeral: bool = False, silent: bool = False, + poll: Poll = MISSING, ) -> Message: """|coro| @@ -1000,6 +1010,11 @@ class Context(discord.abc.Messageable, Generic[BotT]): .. versionadded:: 2.2 + poll: :class:`~discord.Poll` + The poll to send with this message. + + .. versionadded:: 2.4 + Raises -------- ~discord.HTTPException @@ -1037,6 +1052,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): view=view, suppress_embeds=suppress_embeds, silent=silent, + poll=poll, ) # type: ignore # The overloads don't support Optional but the implementation does # Convert the kwargs from None to MISSING to appease the remaining implementations @@ -1052,6 +1068,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): 'suppress_embeds': suppress_embeds, 'ephemeral': ephemeral, 'silent': silent, + 'poll': poll, } if self.interaction.response.is_done(): diff --git a/discord/flags.py b/discord/flags.py index 249c2e8f6..3d31e3a58 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -1257,6 +1257,57 @@ class Intents(BaseFlags): """ return 1 << 21 + @alias_flag_value + def polls(self): + """:class:`bool`: Whether guild and direct messages poll related events are enabled. + + This is a shortcut to set or get both :attr:`guild_polls` and :attr:`dm_polls`. + + This corresponds to the following events: + + - :func:`on_poll_vote_add` (both guilds and DMs) + - :func:`on_poll_vote_remove` (both guilds and DMs) + - :func:`on_raw_poll_vote_add` (both guilds and DMs) + - :func:`on_raw_poll_vote_remove` (both guilds and DMs) + + .. versionadded:: 2.4 + """ + return (1 << 24) | (1 << 25) + + @flag_value + def guild_polls(self): + """:class:`bool`: Whether guild poll related events are enabled. + + See also :attr:`dm_polls` and :attr:`polls`. + + This corresponds to the following events: + + - :func:`on_poll_vote_add` (only for guilds) + - :func:`on_poll_vote_remove` (only for guilds) + - :func:`on_raw_poll_vote_add` (only for guilds) + - :func:`on_raw_poll_vote_remove` (only for guilds) + + .. versionadded:: 2.4 + """ + return 1 << 24 + + @flag_value + def dm_polls(self): + """:class:`bool`: Whether direct messages poll related events are enabled. + + See also :attr:`guild_polls` and :attr:`polls`. + + This corresponds to the following events: + + - :func:`on_poll_vote_add` (only for DMs) + - :func:`on_poll_vote_remove` (only for DMs) + - :func:`on_raw_poll_vote_add` (only for DMs) + - :func:`on_raw_poll_vote_remove` (only for DMs) + + .. versionadded:: 2.4 + """ + return 1 << 25 + @fill_with_flags() class MemberCacheFlags(BaseFlags): diff --git a/discord/http.py b/discord/http.py index f36d191e4..aab710580 100644 --- a/discord/http.py +++ b/discord/http.py @@ -68,6 +68,7 @@ if TYPE_CHECKING: from .embeds import Embed from .message import Attachment from .flags import MessageFlags + from .poll import Poll from .types import ( appinfo, @@ -91,6 +92,7 @@ if TYPE_CHECKING: sticker, welcome_screen, sku, + poll, ) from .types.snowflake import Snowflake, SnowflakeList @@ -154,6 +156,7 @@ def handle_message_parameters( thread_name: str = MISSING, channel_payload: Dict[str, Any] = MISSING, applied_tags: Optional[SnowflakeList] = MISSING, + poll: Optional[Poll] = MISSING, ) -> MultipartParameters: if files is not MISSING and file is not MISSING: raise TypeError('Cannot mix file and files keyword arguments.') @@ -256,6 +259,9 @@ def handle_message_parameters( } payload.update(channel_payload) + if poll not in (MISSING, None): + payload['poll'] = poll._to_dict() # type: ignore + multipart = [] if files: multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)}) @@ -2513,6 +2519,43 @@ class HTTPClient: payload = {k: v for k, v in payload.items() if k in valid_keys} return self.request(Route('PATCH', '/applications/@me'), json=payload, reason=reason) + def get_poll_answer_voters( + self, + channel_id: Snowflake, + message_id: Snowflake, + answer_id: Snowflake, + after: Optional[Snowflake] = None, + limit: Optional[int] = None, + ) -> Response[poll.PollAnswerVoters]: + params = {} + + if after: + params['after'] = int(after) + + if limit is not None: + params['limit'] = limit + + return self.request( + Route( + 'GET', + '/channels/{channel_id}/polls/{message_id}/answers/{answer_id}', + channel_id=channel_id, + message_id=message_id, + answer_id=answer_id, + ), + params=params, + ) + + def end_poll(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: + return self.request( + Route( + 'POST', + '/channels/{channel_id}/polls/{message_id}/expire', + channel_id=channel_id, + message_id=message_id, + ) + ) + async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: try: data = await self.request(Route('GET', '/gateway')) diff --git a/discord/interactions.py b/discord/interactions.py index f471e2040..5702e8b8d 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -78,6 +78,7 @@ if TYPE_CHECKING: from .channel import VoiceChannel, StageChannel, TextChannel, ForumChannel, CategoryChannel, DMChannel, GroupChannel from .threads import Thread from .app_commands.commands import Command, ContextMenu + from .poll import Poll InteractionChannel = Union[ VoiceChannel, @@ -762,6 +763,7 @@ class InteractionResponse(Generic[ClientT]): suppress_embeds: bool = False, silent: bool = False, delete_after: Optional[float] = None, + poll: Poll = MISSING, ) -> None: """|coro| @@ -842,6 +844,7 @@ class InteractionResponse(Generic[ClientT]): allowed_mentions=allowed_mentions, flags=flags, view=view, + poll=poll, ) http = parent._state.http diff --git a/discord/message.py b/discord/message.py index aa1609826..ea62b87f6 100644 --- a/discord/message.py +++ b/discord/message.py @@ -63,6 +63,7 @@ from .mixins import Hashable from .sticker import StickerItem, GuildSticker from .threads import Thread from .channel import PartialMessageable +from .poll import Poll if TYPE_CHECKING: from typing_extensions import Self @@ -1464,6 +1465,7 @@ class PartialMessage(Hashable): view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1484,6 +1486,7 @@ class PartialMessage(Hashable): view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1504,6 +1507,7 @@ class PartialMessage(Hashable): view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1524,6 +1528,7 @@ class PartialMessage(Hashable): view: View = ..., suppress_embeds: bool = ..., silent: bool = ..., + poll: Poll = ..., ) -> Message: ... @@ -1558,6 +1563,30 @@ class PartialMessage(Hashable): return await self.channel.send(content, reference=self, **kwargs) + async def end_poll(self) -> Message: + """|coro| + + Ends the :class:`Poll` attached to this message. + + This can only be done if you are the message author. + + If the poll was successfully ended, then it returns the updated :class:`Message`. + + Raises + ------ + ~discord.HTTPException + Ending the poll failed. + + Returns + ------- + :class:`.Message` + The updated message. + """ + + data = await self._state.http.end_poll(self.channel.id, self.id) + + return Message(state=self._state, channel=self.channel, data=data) + def to_reference(self, *, fail_if_not_exists: bool = True) -> MessageReference: """Creates a :class:`~discord.MessageReference` from the current message. @@ -1728,6 +1757,10 @@ class Message(PartialMessage, Hashable): interaction_metadata: Optional[:class:`.MessageInteractionMetadata`] The metadata of the interaction that this message is a response to. + .. versionadded:: 2.4 + poll: Optional[:class:`Poll`] + The poll attached to this message. + .. versionadded:: 2.4 """ @@ -1764,6 +1797,7 @@ class Message(PartialMessage, Hashable): 'application_id', 'position', 'interaction_metadata', + 'poll', ) if TYPE_CHECKING: @@ -1803,6 +1837,15 @@ class Message(PartialMessage, Hashable): self.application_id: Optional[int] = utils._get_as_snowflake(data, 'application_id') self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] + # This updates the poll so it has the counts, if the message + # was previously cached. + self.poll: Optional[Poll] = state._get_poll(self.id) + if self.poll is None: + try: + self.poll = Poll._from_data(data=data['poll'], message=self, state=state) + except KeyError: + pass + try: # if the channel doesn't have a guild attribute, we handle that self.guild = channel.guild diff --git a/discord/permissions.py b/discord/permissions.py index f18f94a7a..916fa4d2f 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -730,6 +730,22 @@ class Permissions(BaseFlags): """ return 1 << 46 + @flag_value + def send_polls(self) -> int: + """:class:`bool`: Returns ``True`` if a user can send poll messages. + + .. versionadded:: 2.4 + """ + return 1 << 49 + + @make_permission_alias('send_polls') + def create_polls(self) -> int: + """:class:`bool`: An alias for :attr:`send_polls`. + + .. versionadded:: 2.4 + """ + return 1 << 49 + def _augment_from_permissions(cls): cls.VALID_NAMES = set(Permissions.VALID_FLAGS) @@ -850,6 +866,8 @@ class PermissionOverwrite: send_voice_messages: Optional[bool] create_expressions: Optional[bool] create_events: Optional[bool] + send_polls: Optional[bool] + create_polls: Optional[bool] def __init__(self, **kwargs: Optional[bool]): self._values: Dict[str, Optional[bool]] = {} diff --git a/discord/poll.py b/discord/poll.py new file mode 100644 index 000000000..f9b2d04c5 --- /dev/null +++ b/discord/poll.py @@ -0,0 +1,571 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + + +from typing import Optional, List, TYPE_CHECKING, Union, AsyncIterator, Dict + +import datetime + +from .enums import PollLayoutType, try_enum +from . import utils +from .emoji import PartialEmoji, Emoji +from .user import User +from .object import Object +from .errors import ClientException + +if TYPE_CHECKING: + from typing_extensions import Self + + from .message import Message + from .abc import Snowflake + from .state import ConnectionState + from .member import Member + + from .types.poll import ( + PollCreate as PollCreatePayload, + PollMedia as PollMediaPayload, + PollAnswerCount as PollAnswerCountPayload, + Poll as PollPayload, + PollAnswerWithID as PollAnswerWithIDPayload, + PollResult as PollResultPayload, + PollAnswer as PollAnswerPayload, + ) + + +__all__ = ( + 'Poll', + 'PollAnswer', + 'PollMedia', +) + +MISSING = utils.MISSING +PollMediaEmoji = Union[PartialEmoji, Emoji, str] + + +class PollMedia: + """Represents the poll media for a poll item. + + .. versionadded:: 2.4 + + Attributes + ---------- + text: :class:`str` + The displayed text. + emoji: Optional[Union[:class:`PartialEmoji`, :class:`Emoji`]] + The attached emoji for this media. This is only valid for poll answers. + """ + + __slots__ = ('text', 'emoji') + + def __init__(self, /, text: str, emoji: Optional[PollMediaEmoji] = None) -> None: + self.text: str = text + self.emoji: Optional[Union[PartialEmoji, Emoji]] = PartialEmoji.from_str(emoji) if isinstance(emoji, str) else emoji + + def __repr__(self) -> str: + return f'' + + def to_dict(self) -> PollMediaPayload: + payload: PollMediaPayload = {'text': self.text} + + if self.emoji is not None: + payload['emoji'] = self.emoji._to_partial().to_dict() + + return payload + + @classmethod + def from_dict(cls, *, data: PollMediaPayload) -> Self: + emoji = data.get('emoji') + + if emoji: + return cls(text=data['text'], emoji=PartialEmoji.from_dict(emoji)) + return cls(text=data['text']) + + +class PollAnswer: + """Represents a poll's answer. + + .. container:: operations + + .. describe:: str(x) + + Returns this answer's text, if any. + + .. versionadded:: 2.4 + + Attributes + ---------- + id: :class:`int` + The ID of this answer. + media: :class:`PollMedia` + The display data for this answer. + self_voted: :class:`bool` + Whether the current user has voted to this answer or not. + """ + + __slots__ = ('media', 'id', '_state', '_message', '_vote_count', 'self_voted', '_poll') + + def __init__( + self, + *, + message: Optional[Message], + poll: Poll, + data: PollAnswerWithIDPayload, + ) -> None: + self.media: PollMedia = PollMedia.from_dict(data=data['poll_media']) + self.id: int = int(data['answer_id']) + self._message: Optional[Message] = message + self._state: Optional[ConnectionState] = message._state if message else None + self._vote_count: int = 0 + self.self_voted: bool = False + self._poll: Poll = poll + + def _handle_vote_event(self, added: bool, self_voted: bool) -> None: + if added: + self._vote_count += 1 + else: + self._vote_count -= 1 + self.self_voted = self_voted + + def _update_with_results(self, payload: PollAnswerCountPayload) -> None: + self._vote_count = int(payload['count']) + self.self_voted = payload['me_voted'] + + def __str__(self) -> str: + return self.media.text + + def __repr__(self) -> str: + return f'' + + @classmethod + def from_params( + cls, + id: int, + text: str, + emoji: Optional[PollMediaEmoji] = None, + *, + poll: Poll, + message: Optional[Message], + ) -> Self: + poll_media: PollMediaPayload = {'text': text} + if emoji is not None: + emoji = PartialEmoji.from_str(emoji) if isinstance(emoji, str) else emoji._to_partial() + emoji_data = emoji.to_dict() + # No need to remove animated key as it will be ignored + poll_media['emoji'] = emoji_data + + payload: PollAnswerWithIDPayload = {'answer_id': id, 'poll_media': poll_media} + + return cls(data=payload, message=message, poll=poll) + + @property + def text(self) -> str: + """:class:`str`: Returns this answer's displayed text.""" + return self.media.text + + @property + def emoji(self) -> Optional[Union[PartialEmoji, Emoji]]: + """Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]]: Returns this answer's displayed + emoji, if any. + """ + return self.media.emoji + + @property + def vote_count(self) -> int: + """:class:`int`: Returns an approximate count of votes for this answer. + + If the poll is finished, the count is exact. + """ + return self._vote_count + + @property + def poll(self) -> Poll: + """:class:`Poll`: Returns the parent poll of this answer""" + return self._poll + + def _to_dict(self) -> PollAnswerPayload: + return { + 'poll_media': self.media.to_dict(), + } + + async def voters( + self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None + ) -> AsyncIterator[Union[User, Member]]: + """Returns an :term:`asynchronous iterator` representing the users that have voted on this answer. + + The ``after`` parameter must represent a user + and meet the :class:`abc.Snowflake` abc. + + This can only be called when the parent poll was sent to a message. + + Examples + -------- + + Usage :: + + async for voter in poll_answer.voters(): + print(f'{voter} has voted for {poll_answer}!') + + Flattening into a list: :: + + voters = [voter async for voter in poll_answer.voters()] + # voters is now a list of User + + Parameters + ---------- + limit: Optional[:class:`int`] + The maximum number of results to return. + If not provided, returns all the users who + voted on this poll answer. + after: Optional[:class:`abc.Snowflake`] + For pagination, voters are sorted by member. + + Raises + ------ + HTTPException + Retrieving the users failed. + + Yields + ------ + Union[:class:`User`, :class:`Member`] + The member (if retrievable) or the user that has voted + on this poll answer. The case where it can be a :class:`Member` + is in a guild message context. Sometimes it can be a :class:`User` + if the member has left the guild or if the member is not cached. + """ + + if not self._message or not self._state: # Make type checker happy + raise ClientException('You cannot fetch users to a poll not sent with a message') + + if limit is None: + if not self._message.poll: + limit = 100 + else: + limit = self.vote_count or 100 + + while limit > 0: + retrieve = min(limit, 100) + + message = self._message + guild = self._message.guild + state = self._state + after_id = after.id if after else None + + data = await state.http.get_poll_answer_voters( + message.channel.id, message.id, self.id, after=after_id, limit=retrieve + ) + users = data['users'] + + if len(users) == 0: + # No more voters to fetch, terminate loop + break + + limit -= len(users) + after = Object(id=int(users[-1]['id'])) + + if not guild or isinstance(guild, Object): + for raw_user in reversed(users): + yield User(state=self._state, data=raw_user) + continue + + for raw_member in reversed(users): + member_id = int(raw_member['id']) + member = guild.get_member(member_id) + + yield member or User(state=self._state, data=raw_member) + + +class Poll: + """Represents a message's Poll. + + .. versionadded:: 2.4 + + Parameters + ---------- + question: Union[:class:`PollMedia`, :class:`str`] + The poll's displayed question. The text can be up to 300 characters. + duration: :class:`datetime.timedelta` + The duration of the poll. Duration must be in hours. + multiple: :class:`bool` + Whether users are allowed to select more than one answer. + Defaultsto ``False``. + layout_type: :class:`PollLayoutType` + The layout type of the poll. Defaults to :attr:`PollLayoutType.default`. + """ + + __slots__ = ( + 'multiple', + '_answers', + 'duration', + 'layout_type', + '_question_media', + '_message', + '_expiry', + '_finalized', + '_state', + ) + + def __init__( + self, + question: Union[PollMedia, str], + duration: datetime.timedelta, + *, + multiple: bool = False, + layout_type: PollLayoutType = PollLayoutType.default, + ) -> None: + self._question_media: PollMedia = PollMedia(text=question, emoji=None) if isinstance(question, str) else question + self._answers: Dict[int, PollAnswer] = {} + self.duration: datetime.timedelta = duration + + self.multiple: bool = multiple + self.layout_type: PollLayoutType = layout_type + + # NOTE: These attributes are set manually when calling + # _from_data, so it should be ``None`` now. + self._message: Optional[Message] = None + self._state: Optional[ConnectionState] = None + self._finalized: bool = False + self._expiry: Optional[datetime.datetime] = None + + def _update(self, message: Message) -> None: + self._state = message._state + self._message = message + + if not message.poll: + return + + # The message's poll contains the more up to date data. + self._expiry = message.poll.expires_at + self._finalized = message.poll._finalized + + def _update_results(self, data: PollResultPayload) -> None: + self._finalized = data['is_finalized'] + + for count in data['answer_counts']: + answer = self.get_answer(int(count['id'])) + if not answer: + continue + + answer._update_with_results(count) + + def _handle_vote(self, answer_id: int, added: bool, self_voted: bool = False): + answer = self.get_answer(answer_id) + if not answer: + return + + answer._handle_vote_event(added, self_voted) + + @classmethod + def _from_data(cls, *, data: PollPayload, message: Message, state: ConnectionState) -> Self: + multiselect = data.get('allow_multiselect', False) + layout_type = try_enum(PollLayoutType, data.get('layout_type', 1)) + question_data = data.get('question') + question = question_data.get('text') + expiry = utils.parse_time(data['expiry']) # If obtained via API, then expiry is set. + duration = expiry - message.created_at + # self.created_at = message.created_at + # duration = self.created_at - expiry + + if (duration.total_seconds() / 3600) > 168: # As the duration may exceed little milliseconds then we fix it + duration = datetime.timedelta(days=7) + + self = cls( + duration=duration, + multiple=multiselect, + layout_type=layout_type, + question=question, + ) + self._answers = { + int(answer['answer_id']): PollAnswer(data=answer, message=message, poll=self) for answer in data['answers'] + } + self._message = message + self._state = state + self._expiry = expiry + + try: + self._update_results(data['results']) + except KeyError: + pass + + return self + + def _to_dict(self) -> PollCreatePayload: + data: PollCreatePayload = { + 'allow_multiselect': self.multiple, + 'question': self._question_media.to_dict(), + 'duration': self.duration.total_seconds() / 3600, + 'layout_type': self.layout_type.value, + 'answers': [answer._to_dict() for answer in self.answers], + } + return data + + def __repr__(self) -> str: + return f"" + + @property + def question(self) -> str: + """:class:`str`: Returns this poll answer question string.""" + return self._question_media.text + + @property + def answers(self) -> List[PollAnswer]: + """List[:class:`PollAnswer`]: Returns a read-only copy of the answers""" + return list(self._answers.values()) + + @property + def expires_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: A datetime object representing the poll expiry. + + .. note:: + + This will **always** be ``None`` for stateless polls. + """ + return self._expiry + + @property + def created_at(self) -> Optional[datetime.datetime]: + """:class:`datetime.datetime`: Returns the poll's creation time, or ``None`` if user-created.""" + + if not self._message: + return + return self._message.created_at + + @property + def message(self) -> Optional[Message]: + """:class:`Message`: The message this poll is from.""" + return self._message + + @property + def total_votes(self) -> int: + """:class:`int`: Returns the sum of all the answer votes.""" + return sum([answer.vote_count for answer in self.answers]) + + def is_finalised(self) -> bool: + """:class:`bool`: Returns whether the poll has finalised. + + This always returns ``False`` for stateless polls. + """ + return self._finalized + + is_finalized = is_finalised + + def copy(self) -> Self: + """Returns a stateless copy of this poll. + + This is meant to be used when you want to edit a stateful poll. + + Returns + ------- + :class:`Poll` + The copy of the poll. + """ + + new = self.__class__(question=self.question, duration=self.duration) + + # We want to return a stateless copy of the poll, so we should not + # override new._answers as our answers may contain a state + for answer in self.answers: + new.add_answer(text=answer.text, emoji=answer.emoji) + + return new + + def add_answer( + self, + *, + text: str, + emoji: Optional[Union[PartialEmoji, Emoji, str]] = None, + ) -> Self: + """Appends a new answer to this poll. + + Parameters + ---------- + text: :class:`str` + The text label for this poll answer. Can be up to 55 + characters. + emoji: Union[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`] + The emoji to display along the text. + + Raises + ------ + ClientException + Cannot append answers to a poll that is active. + + Returns + ------- + :class:`Poll` + This poll with the new answer appended. This allows fluent-style chaining. + """ + + if self._message: + raise ClientException('Cannot append answers to a poll that is active') + + answer = PollAnswer.from_params(id=len(self.answers) + 1, text=text, emoji=emoji, message=self._message, poll=self) + self._answers[answer.id] = answer + return self + + def get_answer( + self, + /, + id: int, + ) -> Optional[PollAnswer]: + """Returns the answer with the provided ID or ``None`` if not found. + + Parameters + ---------- + id: :class:`int` + The ID of the answer to get. + + Returns + ------- + Optional[:class:`PollAnswer`] + The answer. + """ + + return self._answers.get(id) + + async def end(self) -> Self: + """|coro| + + Ends the poll. + + Raises + ------ + ClientException + This poll has no attached message. + HTTPException + Ending the poll failed. + + Returns + ------- + :class:`Poll` + The updated poll. + """ + + if not self._message or not self._state: # Make type checker happy + raise ClientException('This poll has no attached message.') + + self._message = await self._message.end_poll() + + return self diff --git a/discord/raw_models.py b/discord/raw_models.py index 2fd94539e..571be38f1 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -49,6 +49,7 @@ if TYPE_CHECKING: ThreadMembersUpdate, TypingStartEvent, GuildMemberRemoveEvent, + PollVoteActionEvent, ) from .types.command import GuildApplicationCommandPermissions from .message import Message @@ -77,6 +78,7 @@ __all__ = ( 'RawTypingEvent', 'RawMemberRemoveEvent', 'RawAppCommandPermissionsUpdateEvent', + 'RawPollVoteActionEvent', ) @@ -519,3 +521,33 @@ class RawAppCommandPermissionsUpdateEvent(_RawReprMixin): self.permissions: List[AppCommandPermissions] = [ AppCommandPermissions(data=perm, guild=self.guild, state=state) for perm in data['permissions'] ] + + +class RawPollVoteActionEvent(_RawReprMixin): + """Represents the payload for a :func:`on_raw_poll_vote_add` or :func:`on_raw_poll_vote_remove` + event. + + .. versionadded:: 2.4 + + Attributes + ---------- + user_id: :class:`int` + The ID of the user that added or removed a vote. + channel_id: :class:`int` + The channel ID where the poll vote action took place. + message_id: :class:`int` + The message ID that contains the poll the user added or removed their vote on. + guild_id: Optional[:class:`int`] + The guild ID where the vote got added or removed, if applicable.. + answer_id: :class:`int` + The poll answer's ID the user voted on. + """ + + __slots__ = ('user_id', 'channel_id', 'message_id', 'guild_id', 'answer_id') + + def __init__(self, data: PollVoteActionEvent) -> None: + self.user_id: int = int(data['user_id']) + self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data['message_id']) + self.guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') + self.answer_id: int = int(data['answer_id']) diff --git a/discord/state.py b/discord/state.py index a966cb667..032dc2645 100644 --- a/discord/state.py +++ b/discord/state.py @@ -89,6 +89,7 @@ if TYPE_CHECKING: from .ui.item import Item from .ui.dynamic import DynamicItem from .app_commands import CommandTree, Translator + from .poll import Poll from .types.automod import AutoModerationRule, AutoModerationActionExecution from .types.snowflake import Snowflake @@ -509,6 +510,12 @@ class ConnectionState(Generic[ClientT]): def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None + def _get_poll(self, msg_id: Optional[int]) -> Optional[Poll]: + message = self._get_message(msg_id) + if not message: + return + return message.poll + def _add_guild_from_data(self, data: GuildPayload) -> Guild: guild = Guild(data=data, state=self) self._add_guild(guild) @@ -533,6 +540,13 @@ class ConnectionState(Generic[ClientT]): return channel or PartialMessageable(state=self, guild_id=guild_id, id=channel_id), guild + def _update_poll_counts(self, message: Message, answer_id: int, added: bool, self_voted: bool = False) -> Optional[Poll]: + poll = message.poll + if not poll: + return + poll._handle_vote(answer_id, added, self_voted) + return poll + async def chunker( self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None ) -> None: @@ -1619,6 +1633,52 @@ class ConnectionState(Generic[ClientT]): entitlement = Entitlement(data=data, state=self) self.dispatch('entitlement_delete', entitlement) + def parse_message_poll_vote_add(self, data: gw.PollVoteActionEvent) -> None: + raw = RawPollVoteActionEvent(data) + + self.dispatch('raw_poll_vote_add', raw) + + message = self._get_message(raw.message_id) + guild = self._get_guild(raw.guild_id) + + if guild: + user = guild.get_member(raw.user_id) + else: + user = self.get_user(raw.user_id) + + if message and user: + poll = self._update_poll_counts(message, raw.answer_id, True, raw.user_id == self.self_id) + if not poll: + _log.warning( + 'POLL_VOTE_ADD referencing message with ID: %s does not have a poll. Discarding.', raw.message_id + ) + return + + self.dispatch('poll_vote_add', user, poll.get_answer(raw.answer_id)) + + def parse_message_poll_vote_remove(self, data: gw.PollVoteActionEvent) -> None: + raw = RawPollVoteActionEvent(data) + + self.dispatch('raw_poll_vote_remove', raw) + + message = self._get_message(raw.message_id) + guild = self._get_guild(raw.guild_id) + + if guild: + user = guild.get_member(raw.user_id) + else: + user = self.get_user(raw.user_id) + + if message and user: + poll = self._update_poll_counts(message, raw.answer_id, False, raw.user_id == self.self_id) + if not poll: + _log.warning( + 'POLL_VOTE_REMOVE referencing message with ID: %s does not have a poll. Discarding.', raw.message_id + ) + return + + self.dispatch('poll_vote_remove', user, poll.get_answer(raw.answer_id)) + def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, (TextChannel, Thread, VoiceChannel)): return channel.guild.get_member(user_id) diff --git a/discord/types/gateway.py b/discord/types/gateway.py index c0908435f..b79bd9ca9 100644 --- a/discord/types/gateway.py +++ b/discord/types/gateway.py @@ -352,3 +352,11 @@ class GuildAuditLogEntryCreate(AuditLogEntry): EntitlementCreateEvent = EntitlementUpdateEvent = EntitlementDeleteEvent = Entitlement + + +class PollVoteActionEvent(TypedDict): + user_id: Snowflake + channel_id: Snowflake + message_id: Snowflake + guild_id: NotRequired[Snowflake] + answer_id: int diff --git a/discord/types/message.py b/discord/types/message.py index 35d80be42..16912d628 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -37,6 +37,7 @@ from .components import Component from .interactions import MessageInteraction, MessageInteractionMetadata from .sticker import StickerItem from .threads import Thread +from .poll import Poll class PartialMessage(TypedDict): @@ -163,6 +164,7 @@ class Message(PartialMessage): attachments: List[Attachment] embeds: List[Embed] pinned: bool + poll: NotRequired[Poll] type: MessageType member: NotRequired[Member] mention_channels: NotRequired[List[ChannelMention]] diff --git a/discord/types/poll.py b/discord/types/poll.py new file mode 100644 index 000000000..fabdbd48f --- /dev/null +++ b/discord/types/poll.py @@ -0,0 +1,88 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + + +from typing import List, TypedDict, Optional, Literal, TYPE_CHECKING +from typing_extensions import NotRequired + +from .snowflake import Snowflake + +if TYPE_CHECKING: + from .user import User + from .emoji import PartialEmoji + + +LayoutType = Literal[1] # 1 = Default + + +class PollMedia(TypedDict): + text: str + emoji: NotRequired[Optional[PartialEmoji]] + + +class PollAnswer(TypedDict): + poll_media: PollMedia + + +class PollAnswerWithID(PollAnswer): + answer_id: int + + +class PollAnswerCount(TypedDict): + id: Snowflake + count: int + me_voted: bool + + +class PollAnswerVoters(TypedDict): + users: List[User] + + +class PollResult(TypedDict): + is_finalized: bool + answer_counts: List[PollAnswerCount] + + +class PollCreate(TypedDict): + allow_multiselect: bool + answers: List[PollAnswer] + duration: float + layout_type: LayoutType + question: PollMedia + + +# We don't subclass Poll as it will +# still have the duration field, which +# is converted into expiry when poll is +# fetched from a message or returned +# by a `send` method in a Messageable +class Poll(TypedDict): + allow_multiselect: bool + answers: List[PollAnswerWithID] + expiry: str + layout_type: LayoutType + question: PollMedia + results: PollResult diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 767db38cc..d04e21b57 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -72,6 +72,7 @@ if TYPE_CHECKING: from ..channel import VoiceChannel from ..abc import Snowflake from ..ui.view import View + from ..poll import Poll import datetime from ..types.webhook import ( Webhook as WebhookPayload, @@ -541,6 +542,7 @@ def interaction_message_response_params( view: Optional[View] = MISSING, allowed_mentions: Optional[AllowedMentions] = MISSING, previous_allowed_mentions: Optional[AllowedMentions] = None, + poll: Poll = MISSING, ) -> MultipartParameters: if files is not MISSING and file is not MISSING: raise TypeError('Cannot mix file and files keyword arguments.') @@ -608,6 +610,9 @@ def interaction_message_response_params( data['attachments'] = attachments_payload + if poll is not MISSING: + data['poll'] = poll._to_dict() + multipart = [] if files: data = {'type': type, 'data': data} @@ -1597,6 +1602,7 @@ class Webhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> WebhookMessage: ... @@ -1621,6 +1627,7 @@ class Webhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> None: ... @@ -1644,6 +1651,7 @@ class Webhook(BaseWebhook): suppress_embeds: bool = False, silent: bool = False, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> Optional[WebhookMessage]: """|coro| @@ -1734,6 +1742,15 @@ class Webhook(BaseWebhook): .. versionadded:: 2.4 + poll: :class:`Poll` + The poll to send with this message. + + .. warning:: + + When sending a Poll via webhook, you cannot manually end it. + + .. versionadded:: 2.4 + Raises -------- HTTPException @@ -1811,6 +1828,7 @@ class Webhook(BaseWebhook): allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, applied_tags=applied_tag_ids, + poll=poll, ) as params: adapter = async_context.get() thread_id: Optional[int] = None @@ -1838,6 +1856,9 @@ class Webhook(BaseWebhook): message_id = None if msg is None else msg.id self._state.store_view(view, message_id) + if poll is not MISSING and msg: + poll._update(msg) + return msg async def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> WebhookMessage: diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index 198cdf53b..cf23e977b 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: from ..file import File from ..embeds import Embed + from ..poll import Poll from ..mentions import AllowedMentions from ..message import Attachment from ..abc import Snowflake @@ -872,6 +873,7 @@ class SyncWebhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> SyncWebhookMessage: ... @@ -894,6 +896,7 @@ class SyncWebhook(BaseWebhook): suppress_embeds: bool = MISSING, silent: bool = MISSING, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> None: ... @@ -915,6 +918,7 @@ class SyncWebhook(BaseWebhook): suppress_embeds: bool = False, silent: bool = False, applied_tags: List[ForumTag] = MISSING, + poll: Poll = MISSING, ) -> Optional[SyncWebhookMessage]: """Sends a message using the webhook. @@ -979,6 +983,14 @@ class SyncWebhook(BaseWebhook): in the UI, but will not actually send a notification. .. versionadded:: 2.2 + poll: :class:`Poll` + The poll to send with this message. + + .. warning:: + + When sending a Poll via webhook, you cannot manually end it. + + .. versionadded:: 2.4 Raises -------- @@ -1037,6 +1049,7 @@ class SyncWebhook(BaseWebhook): previous_allowed_mentions=previous_mentions, flags=flags, applied_tags=applied_tag_ids, + poll=poll, ) as params: adapter: WebhookAdapter = _get_webhook_adapter() thread_id: Optional[int] = None @@ -1054,8 +1067,15 @@ class SyncWebhook(BaseWebhook): wait=wait, ) + msg = None + if wait: - return self._create_message(data, thread=thread) + msg = self._create_message(data, thread=thread) + + if poll is not MISSING and msg: + poll._update(msg) + + return msg def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> SyncWebhookMessage: """Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook. diff --git a/docs/api.rst b/docs/api.rst index b4285c3c1..13b49df5b 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1047,6 +1047,42 @@ Messages :param payload: The raw event payload data. :type payload: :class:`RawBulkMessageDeleteEvent` +Polls +~~~~~~ + +.. function:: on_poll_vote_add(user, answer) + on_poll_vote_remove(user, answer) + + Called when a :class:`Poll` gains or loses a vote. If the ``user`` or ``message`` + are not cached then this event will not be called. + + This requires :attr:`Intents.message_content` and :attr:`Intents.polls` to be enabled. + + .. note:: + + If the poll allows multiple answers and the user removes or adds multiple votes, this + event will be called as many times as votes that are added or removed. + + .. versionadded:: 2.4 + + :param user: The user that performed the action. + :type user: Union[:class:`User`, :class:`Member`] + :param answer: The answer the user voted or removed their vote from. + :type answer: :class:`PollAnswer` + +.. function:: on_raw_poll_vote_add(payload) + on_raw_poll_vote_remove(payload) + + Called when a :class:`Poll` gains or loses a vote. Unlike :func:`on_poll_vote_add` and :func:`on_poll_vote_remove` + this is called regardless of the state of the internal user and message cache. + + This requires :attr:`Intents.message_content` and :attr:`Intents.polls` to be enabled. + + .. versionadded:: 2.4 + + :param payload: The raw event payload data. + :type payload: :class:`RawPollVoteActionEvent` + Reactions ~~~~~~~~~~ @@ -3577,6 +3613,16 @@ of :class:`enum.Enum`. The entitlement owner is a user. +.. class:: PollLayoutType + + Represents how a poll answers are shown + + .. versionadded:: 2.4 + + .. attribute:: default + + The default layout. + .. _discord-api-audit-logs: Audit Log Data @@ -5007,6 +5053,14 @@ RawAppCommandPermissionsUpdateEvent .. autoclass:: RawAppCommandPermissionsUpdateEvent() :members: +RawPollVoteActionEvent +~~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: RawPollVoteActionEvent + +.. autoclass:: RawPollVoteActionEvent() + :members: + PartialWebhookGuild ~~~~~~~~~~~~~~~~~~~~ @@ -5288,6 +5342,25 @@ ForumTag .. autoclass:: ForumTag :members: +Poll +~~~~ + +.. attributetable:: Poll + +.. autoclass:: Poll() + :members: + +.. attributetable:: PollAnswer + +.. autoclass:: PollAnswer() + :members: + :inherited-members: + +.. attributetable:: PollMedia + +.. autoclass:: PollMedia() + :members: + Exceptions ------------