diff --git a/README.rst b/README.rst index 621a69500..b2112f9d6 100644 --- a/README.rst +++ b/README.rst @@ -27,6 +27,13 @@ Installing To install the library without full voice support, you can just run the following command: +.. note:: + + A `Virtual Environment `__ is recommended to install + the library, especially on Linux where the system Python is externally managed and restricts which + packages you can install on it. + + .. code:: sh # Linux/macOS diff --git a/discord/__init__.py b/discord/__init__.py index 765719b68..c206f650f 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -70,6 +70,8 @@ from .components import * from .threads import * from .automod import * from .poll import * +from .soundboard import * +from .subscription import * class VersionInfo(NamedTuple): @@ -84,4 +86,10 @@ version_info: VersionInfo = VersionInfo(major=2, minor=5, micro=0, releaselevel= logging.getLogger(__name__).addHandler(logging.NullHandler()) +# This is a backwards compatibility hack and should be removed in v3 +# Essentially forcing the exception to have different base classes +# In the future, this should only inherit from ClientException +if len(MissingApplicationID.__bases__) == 1: + MissingApplicationID.__bases__ = (app_commands.AppCommandError, ClientException) + del logging, NamedTuple, Literal, VersionInfo diff --git a/discord/abc.py b/discord/abc.py index 7f10811c4..891404b33 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -26,6 +26,7 @@ from __future__ import annotations import copy import time +import secrets import asyncio from datetime import datetime from typing import ( @@ -1005,11 +1006,15 @@ class GuildChannel: base_attrs: Dict[str, Any], *, name: Optional[str] = None, + category: Optional[CategoryChannel] = None, reason: Optional[str] = None, ) -> Self: base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites] base_attrs['parent_id'] = self.category_id base_attrs['name'] = name or self.name + if category is not None: + base_attrs['parent_id'] = category.id + guild_id = self.guild.id cls = self.__class__ data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) @@ -1019,7 +1024,13 @@ class GuildChannel: self.guild._channels[obj.id] = obj # type: ignore # obj is a GuildChannel return obj - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> Self: + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, + ) -> Self: """|coro| Clones this channel. This creates a channel with the same properties @@ -1034,6 +1045,11 @@ class GuildChannel: name: Optional[:class:`str`] The name of the new channel. If not provided, defaults to this channel name. + category: Optional[:class:`~discord.CategoryChannel`] + The category the new channel belongs to. + This parameter is ignored if cloning a category channel. + + .. versionadded:: 2.5 reason: Optional[:class:`str`] The reason for cloning this channel. Shows up on the audit log. @@ -1515,10 +1531,11 @@ class Messageable: .. versionadded:: 1.4 reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`, :class:`~discord.PartialMessage`] - A reference to the :class:`~discord.Message` to which you are replying, this can be created using - :meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. You can control - whether this mentions the author of the referenced message using the :attr:`~discord.AllowedMentions.replied_user` - attribute of ``allowed_mentions`` or by setting ``mention_author``. + A reference to the :class:`~discord.Message` to which you are referencing, this can be created using + :meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. + In the event of a replying reference, you can control whether this mentions the author of the referenced + message using the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions`` or by + setting ``mention_author``. .. versionadded:: 1.6 @@ -1598,6 +1615,9 @@ class Messageable: else: flags = MISSING + if nonce is None: + nonce = secrets.randbits(64) + with handle_message_parameters( content=content, tts=tts, diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py index cd6eafaf3..a872fb4be 100644 --- a/discord/app_commands/commands.py +++ b/discord/app_commands/commands.py @@ -2821,7 +2821,7 @@ def allowed_installs( return inner -def default_permissions(**perms: bool) -> Callable[[T], T]: +def default_permissions(perms_obj: Optional[Permissions] = None, /, **perms: bool) -> Callable[[T], T]: r"""A decorator that sets the default permissions needed to execute this command. When this decorator is used, by default users must have these permissions to execute the command. @@ -2845,8 +2845,12 @@ def default_permissions(**perms: bool) -> Callable[[T], T]: ----------- \*\*perms: :class:`bool` Keyword arguments denoting the permissions to set as the default. + perms_obj: :class:`~discord.Permissions` + A permissions object as positional argument. This can be used in combination with ``**perms``. - Example + .. versionadded:: 2.5 + + Examples --------- .. code-block:: python3 @@ -2855,9 +2859,21 @@ def default_permissions(**perms: bool) -> Callable[[T], T]: @app_commands.default_permissions(manage_messages=True) async def test(interaction: discord.Interaction): await interaction.response.send_message('You may or may not have manage messages.') + + .. code-block:: python3 + + ADMIN_PERMS = discord.Permissions(administrator=True) + + @app_commands.command() + @app_commands.default_permissions(ADMIN_PERMS, manage_messages=True) + async def test(interaction: discord.Interaction): + await interaction.response.send_message('You may or may not have manage messages.') """ - permissions = Permissions(**perms) + if perms_obj is not None: + permissions = perms_obj | Permissions(**perms) + else: + permissions = Permissions(**perms) def decorator(func: T) -> T: if isinstance(func, (Command, Group, ContextMenu)): diff --git a/discord/app_commands/errors.py b/discord/app_commands/errors.py index abdc9f2f0..2efb4e5b0 100644 --- a/discord/app_commands/errors.py +++ b/discord/app_commands/errors.py @@ -27,7 +27,7 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING, List, Optional, Sequence, Union from ..enums import AppCommandOptionType, AppCommandType, Locale -from ..errors import DiscordException, HTTPException, _flatten_error_dict +from ..errors import DiscordException, HTTPException, _flatten_error_dict, MissingApplicationID as MissingApplicationID from ..utils import _human_join __all__ = ( @@ -59,11 +59,6 @@ if TYPE_CHECKING: CommandTypes = Union[Command[Any, ..., Any], Group, ContextMenu] -APP_ID_NOT_FOUND = ( - 'Client does not have an application_id set. Either the function was called before on_ready ' - 'was called or application_id was not passed to the Client constructor.' -) - class AppCommandError(DiscordException): """The base exception type for all application command related errors. @@ -422,19 +417,6 @@ class CommandSignatureMismatch(AppCommandError): super().__init__(msg) -class MissingApplicationID(AppCommandError): - """An exception raised when the client does not have an application ID set. - An application ID is required for syncing application commands. - - This inherits from :exc:`~discord.app_commands.AppCommandError`. - - .. versionadded:: 2.0 - """ - - def __init__(self, message: Optional[str] = None): - super().__init__(message or APP_ID_NOT_FOUND) - - def _get_command_error( index: str, inner: Any, diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py index d012c52b9..e7b001727 100644 --- a/discord/app_commands/transformers.py +++ b/discord/app_commands/transformers.py @@ -34,6 +34,7 @@ from typing import ( ClassVar, Coroutine, Dict, + Generic, List, Literal, Optional, @@ -56,6 +57,7 @@ from ..user import User from ..role import Role from ..member import Member from ..message import Attachment +from .._types import ClientT __all__ = ( 'Transformer', @@ -191,7 +193,7 @@ class CommandParameter: return self.name if self._rename is MISSING else str(self._rename) -class Transformer: +class Transformer(Generic[ClientT]): """The base class that allows a type annotation in an application command parameter to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one from this type. @@ -304,7 +306,7 @@ class Transformer: else: return name - async def transform(self, interaction: Interaction, value: Any, /) -> Any: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: """|maybecoro| Transforms the converted option value into another value. @@ -324,7 +326,7 @@ class Transformer: raise NotImplementedError('Derived classes need to implement this.') async def autocomplete( - self, interaction: Interaction, value: Union[int, float, str], / + self, interaction: Interaction[ClientT], value: Union[int, float, str], / ) -> List[Choice[Union[int, float, str]]]: """|coro| @@ -352,7 +354,7 @@ class Transformer: raise NotImplementedError('Derived classes can implement this.') -class IdentityTransformer(Transformer): +class IdentityTransformer(Transformer[ClientT]): def __init__(self, type: AppCommandOptionType) -> None: self._type = type @@ -360,7 +362,7 @@ class IdentityTransformer(Transformer): def type(self) -> AppCommandOptionType: return self._type - async def transform(self, interaction: Interaction, value: Any, /) -> Any: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: return value @@ -489,7 +491,7 @@ class EnumNameTransformer(Transformer): return self._enum[value] -class InlineTransformer(Transformer): +class InlineTransformer(Transformer[ClientT]): def __init__(self, annotation: Any) -> None: super().__init__() self.annotation: Any = annotation @@ -502,7 +504,7 @@ class InlineTransformer(Transformer): def type(self) -> AppCommandOptionType: return AppCommandOptionType.string - async def transform(self, interaction: Interaction, value: Any, /) -> Any: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: return await self.annotation.transform(interaction, value) @@ -611,18 +613,18 @@ else: return transformer -class MemberTransformer(Transformer): +class MemberTransformer(Transformer[ClientT]): @property def type(self) -> AppCommandOptionType: return AppCommandOptionType.user - async def transform(self, interaction: Interaction, value: Any, /) -> Member: + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Member: if not isinstance(value, Member): raise TransformerError(value, self.type, self) return value -class BaseChannelTransformer(Transformer): +class BaseChannelTransformer(Transformer[ClientT]): def __init__(self, *channel_types: Type[Any]) -> None: super().__init__() if len(channel_types) == 1: @@ -654,22 +656,22 @@ class BaseChannelTransformer(Transformer): def channel_types(self) -> List[ChannelType]: return self._channel_types - async def transform(self, interaction: Interaction, value: Any, /): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): resolved = value.resolve() if resolved is None or not isinstance(resolved, self._types): raise TransformerError(value, AppCommandOptionType.channel, self) return resolved -class RawChannelTransformer(BaseChannelTransformer): - async def transform(self, interaction: Interaction, value: Any, /): +class RawChannelTransformer(BaseChannelTransformer[ClientT]): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): if not isinstance(value, self._types): raise TransformerError(value, AppCommandOptionType.channel, self) return value -class UnionChannelTransformer(BaseChannelTransformer): - async def transform(self, interaction: Interaction, value: Any, /): +class UnionChannelTransformer(BaseChannelTransformer[ClientT]): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): if isinstance(value, self._types): return value diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py index bc0d68ec7..90b9a21ab 100644 --- a/discord/app_commands/tree.py +++ b/discord/app_commands/tree.py @@ -73,7 +73,7 @@ if TYPE_CHECKING: from .commands import ContextMenuCallback, CommandCallback, P, T ErrorFunc = Callable[ - [Interaction, AppCommandError], + [Interaction[ClientT], AppCommandError], Coroutine[Any, Any, Any], ] @@ -833,7 +833,7 @@ class CommandTree(Generic[ClientT]): else: _log.error('Ignoring exception in command tree', exc_info=error) - def error(self, coro: ErrorFunc) -> ErrorFunc: + def error(self, coro: ErrorFunc[ClientT]) -> ErrorFunc[ClientT]: """A decorator that registers a coroutine as a local error handler. This must match the signature of the :meth:`on_error` callback. diff --git a/discord/audit_logs.py b/discord/audit_logs.py index fc1bc298b..59d563829 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -235,6 +235,10 @@ def _transform_automod_actions(entry: AuditLogEntry, data: List[AutoModerationAc return [AutoModRuleAction.from_data(action) for action in data] +def _transform_default_emoji(entry: AuditLogEntry, data: str) -> PartialEmoji: + return PartialEmoji(name=data) + + E = TypeVar('E', bound=enums.Enum) @@ -341,6 +345,8 @@ class AuditLogChanges: 'available_tags': (None, _transform_forum_tags), 'flags': (None, _transform_overloaded_flags), 'default_reaction_emoji': (None, _transform_default_reaction), + 'emoji_name': ('emoji', _transform_default_emoji), + 'user_id': ('user', _transform_member_id) } # fmt: on diff --git a/discord/channel.py b/discord/channel.py index 55b25a03c..a306707d6 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -47,7 +47,16 @@ import datetime import discord.abc from .scheduled_event import ScheduledEvent from .permissions import PermissionOverwrite, Permissions -from .enums import ChannelType, ForumLayoutType, ForumOrderType, PrivacyLevel, try_enum, VideoQualityMode, EntityType +from .enums import ( + ChannelType, + ForumLayoutType, + ForumOrderType, + PrivacyLevel, + try_enum, + VideoQualityMode, + EntityType, + VoiceChannelEffectAnimationType, +) from .mixins import Hashable from . import utils from .utils import MISSING @@ -56,8 +65,10 @@ from .errors import ClientException from .stage_instance import StageInstance from .threads import Thread from .partial_emoji import _EmojiTag, PartialEmoji -from .flags import ChannelFlags +from .flags import ChannelFlags, MessageFlags from .http import handle_message_parameters +from .object import Object +from .soundboard import BaseSoundboardSound, SoundboardDefaultSound __all__ = ( 'TextChannel', @@ -69,6 +80,8 @@ __all__ = ( 'ForumChannel', 'GroupChannel', 'PartialMessageable', + 'VoiceChannelEffect', + 'VoiceChannelSoundEffect', ) if TYPE_CHECKING: @@ -76,7 +89,6 @@ if TYPE_CHECKING: from .types.threads import ThreadArchiveDuration from .role import Role - from .object import Object from .member import Member, VoiceState from .abc import Snowflake, SnowflakeTime from .embeds import Embed @@ -100,8 +112,11 @@ if TYPE_CHECKING: ForumChannel as ForumChannelPayload, MediaChannel as MediaChannelPayload, ForumTag as ForumTagPayload, + VoiceChannelEffect as VoiceChannelEffectPayload, ) from .types.snowflake import SnowflakeList + from .types.soundboard import BaseSoundboardSound as BaseSoundboardSoundPayload + from .soundboard import SoundboardSound OverwriteKeyT = TypeVar('OverwriteKeyT', Role, BaseUser, Object, Union[Role, Member, Object]) @@ -111,6 +126,121 @@ class ThreadWithMessage(NamedTuple): message: Message +class VoiceChannelEffectAnimation(NamedTuple): + id: int + type: VoiceChannelEffectAnimationType + + +class VoiceChannelSoundEffect(BaseSoundboardSound): + """Represents a Discord voice channel sound effect. + + .. versionadded:: 2.5 + + .. container:: operations + + .. describe:: x == y + + Checks if two sound effects are equal. + + .. describe:: x != y + + Checks if two sound effects are not equal. + + .. describe:: hash(x) + + Returns the sound effect's hash. + + Attributes + ------------ + id: :class:`int` + The ID of the sound. + volume: :class:`float` + The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%). + """ + + __slots__ = ('_state',) + + def __init__(self, *, state: ConnectionState, id: int, volume: float): + data: BaseSoundboardSoundPayload = { + 'sound_id': id, + 'volume': volume, + } + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id} volume={self.volume}>" + + @property + def created_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: Returns the snowflake's creation time in UTC. + Returns ``None`` if it's a default sound.""" + if self.is_default(): + return None + else: + return utils.snowflake_time(self.id) + + def is_default(self) -> bool: + """:class:`bool`: Whether it's a default sound or not.""" + # if it's smaller than the Discord Epoch it cannot be a snowflake + return self.id < utils.DISCORD_EPOCH + + +class VoiceChannelEffect: + """Represents a Discord voice channel effect. + + .. versionadded:: 2.5 + + Attributes + ------------ + channel: :class:`VoiceChannel` + The channel in which the effect is sent. + user: Optional[:class:`Member`] + The user who sent the effect. ``None`` if not found in cache. + animation: Optional[:class:`VoiceChannelEffectAnimation`] + The animation the effect has. Returns ``None`` if the effect has no animation. + emoji: Optional[:class:`PartialEmoji`] + The emoji of the effect. + sound: Optional[:class:`VoiceChannelSoundEffect`] + The sound of the effect. Returns ``None`` if it's an emoji effect. + """ + + __slots__ = ('channel', 'user', 'animation', 'emoji', 'sound') + + def __init__(self, *, state: ConnectionState, data: VoiceChannelEffectPayload, guild: Guild): + self.channel: VoiceChannel = guild.get_channel(int(data['channel_id'])) # type: ignore # will always be a VoiceChannel + self.user: Optional[Member] = guild.get_member(int(data['user_id'])) + self.animation: Optional[VoiceChannelEffectAnimation] = None + + animation_id = data.get('animation_id') + if animation_id is not None: + animation_type = try_enum(VoiceChannelEffectAnimationType, data['animation_type']) # type: ignore # cannot be None here + self.animation = VoiceChannelEffectAnimation(id=animation_id, type=animation_type) + + emoji = data.get('emoji') + self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None + self.sound: Optional[VoiceChannelSoundEffect] = None + + sound_id: Optional[int] = utils._get_as_snowflake(data, 'sound_id') + if sound_id is not None: + sound_volume = data.get('sound_volume') or 0.0 + self.sound = VoiceChannelSoundEffect(state=state, id=sound_id, volume=sound_volume) + + def __repr__(self) -> str: + attrs = [ + ('channel', self.channel), + ('user', self.user), + ('animation', self.animation), + ('emoji', self.emoji), + ('sound', self.sound), + ] + inner = ' '.join('%s=%r' % t for t in attrs) + return f"<{self.__class__.__name__} {inner}>" + + def is_sound(self) -> bool: + """:class:`bool`: Whether the effect is a sound or not.""" + return self.sound is not None + + class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """Represents a Discord guild text channel. @@ -395,9 +525,26 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel: + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, + ) -> TextChannel: + base: Dict[Any, Any] = { + 'topic': self.topic, + 'nsfw': self.nsfw, + 'default_auto_archive_duration': self.default_auto_archive_duration, + 'default_thread_rate_limit_per_user': self.default_thread_slowmode_delay, + } + if not self.is_news(): + base['rate_limit_per_user'] = self.slowmode_delay return await self._clone_impl( - {'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason + base, + name=name, + category=category, + reason=reason, ) async def delete_messages(self, messages: Iterable[Snowflake], *, reason: Optional[str] = None) -> None: @@ -1249,6 +1396,27 @@ class VocalGuildChannel(discord.abc.Messageable, discord.abc.Connectable, discor data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, *, name: Optional[str] = None, category: Optional[CategoryChannel] = None, reason: Optional[str] = None + ) -> Self: + base = { + 'bitrate': self.bitrate, + 'user_limit': self.user_limit, + 'rate_limit_per_user': self.slowmode_delay, + 'nsfw': self.nsfw, + 'video_quality_mode': self.video_quality_mode.value, + } + if self.rtc_region: + base['rtc_region'] = self.rtc_region + + return await self._clone_impl( + base, + name=name, + category=category, + reason=reason, + ) + class VoiceChannel(VocalGuildChannel): """Represents a Discord guild voice channel. @@ -1343,10 +1511,6 @@ class VoiceChannel(VocalGuildChannel): """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.voice - @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel: - return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason) - @overload async def edit(self) -> None: ... @@ -1456,6 +1620,35 @@ class VoiceChannel(VocalGuildChannel): # the payload will always be the proper channel payload return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + async def send_sound(self, sound: Union[SoundboardSound, SoundboardDefaultSound], /) -> None: + """|coro| + + Sends a soundboard sound for this channel. + + You must have :attr:`~Permissions.speak` and :attr:`~Permissions.use_soundboard` to do this. + Additionally, you must have :attr:`~Permissions.use_external_sounds` if the sound is from + a different guild. + + .. versionadded:: 2.5 + + Parameters + ----------- + sound: Union[:class:`SoundboardSound`, :class:`SoundboardDefaultSound`] + The sound to send for this channel. + + Raises + ------- + Forbidden + You do not have permissions to send a sound for this channel. + HTTPException + Sending the sound failed. + """ + payload = {'sound_id': sound.id} + if not isinstance(sound, SoundboardDefaultSound) and self.guild.id != sound.guild.id: + payload['source_guild_id'] = sound.guild.id + + await self._state.http.send_soundboard_sound(self.id, **payload) + class StageChannel(VocalGuildChannel): """Represents a Discord guild stage channel. @@ -1588,10 +1781,6 @@ class StageChannel(VocalGuildChannel): """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.stage_voice - @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel: - return await self._clone_impl({}, name=name, reason=reason) - @property def instance(self) -> Optional[StageInstance]: """Optional[:class:`StageInstance`]: The running stage instance of the stage channel. @@ -1869,7 +2058,13 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel: + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, + ) -> CategoryChannel: return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) @overload @@ -2297,6 +2492,14 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): def _sorting_bucket(self) -> int: return ChannelType.text.value + @property + def members(self) -> List[Member]: + """List[:class:`Member`]: Returns all members that can see this channel. + + .. versionadded:: 2.5 + """ + return [m for m in self.guild.members if self.permissions_for(m).read_messages] + @property def _scheduled_event_entity_type(self) -> Optional[EntityType]: return None @@ -2386,9 +2589,33 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): return self._type == ChannelType.media.value @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> ForumChannel: + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel], + reason: Optional[str] = None, + ) -> ForumChannel: + base = { + 'topic': self.topic, + 'rate_limit_per_user': self.slowmode_delay, + 'nsfw': self.nsfw, + 'default_auto_archive_duration': self.default_auto_archive_duration, + 'available_tags': [tag.to_dict() for tag in self.available_tags], + 'default_thread_rate_limit_per_user': self.default_thread_slowmode_delay, + } + if self.default_sort_order: + base['default_sort_order'] = self.default_sort_order.value + if self.default_reaction_emoji: + base['default_reaction_emoji'] = self.default_reaction_emoji._to_forum_tag_payload() + if not self.is_media() and self.default_layout: + base['default_forum_layout'] = self.default_layout.value + return await self._clone_impl( - {'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason + base, + name=name, + category=category, + reason=reason, ) @overload @@ -2719,8 +2946,6 @@ class ForumChannel(discord.abc.GuildChannel, Hashable): raise TypeError(f'view parameter must be View not {view.__class__.__name__}') if suppress_embeds: - from .message import MessageFlags # circular import - flags = MessageFlags._from_value(4) else: flags = MISSING @@ -3312,6 +3537,14 @@ class PartialMessageable(discord.abc.Messageable, Hashable): return Permissions.none() + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the channel. + + .. versionadded:: 2.5 + """ + return f'<#{self.id}>' + def get_partial_message(self, message_id: int, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. diff --git a/discord/client.py b/discord/client.py index 50f76d5a2..ef7980ec4 100644 --- a/discord/client.py +++ b/discord/client.py @@ -53,7 +53,7 @@ from .user import User, ClientUser from .invite import Invite from .template import Template from .widget import Widget -from .guild import Guild +from .guild import Guild, GuildPreview from .emoji import Emoji from .channel import _threaded_channel_factory, PartialMessageable from .enums import ChannelType, EntitlementOwnerType @@ -77,6 +77,7 @@ from .ui.dynamic import DynamicItem from .stage_instance import StageInstance from .threads import Thread from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory +from .soundboard import SoundboardDefaultSound, SoundboardSound if TYPE_CHECKING: from types import TracebackType @@ -84,7 +85,7 @@ if TYPE_CHECKING: from typing_extensions import Self from .abc import Messageable, PrivateChannel, Snowflake, SnowflakeTime - from .app_commands import Command, ContextMenu, MissingApplicationID + from .app_commands import Command, ContextMenu from .automod import AutoModAction, AutoModRule from .channel import DMChannel, GroupChannel from .ext.commands import AutoShardedBot, Bot, Context, CommandError @@ -118,6 +119,7 @@ if TYPE_CHECKING: from .voice_client import VoiceProtocol from .audit_logs import AuditLogEntry from .poll import PollAnswer + from .subscription import Subscription # fmt: off @@ -250,7 +252,7 @@ class Client: .. versionadded:: 2.0 connector: Optional[:class:`aiohttp.BaseConnector`] - The aiohhtp connector to use for this client. This can be used to control underlying aiohttp + The aiohttp connector to use for this client. This can be used to control underlying aiohttp behavior, such as setting a dns resolver or sslcontext. .. versionadded:: 2.5 @@ -383,6 +385,14 @@ class Client: """ return self._connection.stickers + @property + def soundboard_sounds(self) -> List[SoundboardSound]: + """List[:class:`.SoundboardSound`]: The soundboard sounds that the connected client has. + + .. versionadded:: 2.5 + """ + return self._connection.soundboard_sounds + @property def cached_messages(self) -> Sequence[Message]: """Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached. @@ -1109,6 +1119,23 @@ class Client: """ return self._connection.get_sticker(id) + def get_soundboard_sound(self, id: int, /) -> Optional[SoundboardSound]: + """Returns a soundboard sound with the given ID. + + .. versionadded:: 2.5 + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`.SoundboardSound`] + The soundboard sound or ``None`` if not found. + """ + return self._connection.get_soundboard_sound(id) + def get_all_channels(self) -> Generator[GuildChannel, None, None]: """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. @@ -1347,6 +1374,18 @@ class Client: ) -> Union[str, bytes]: ... + # Entitlements + @overload + async def wait_for( + self, + event: Literal['entitlement_create', 'entitlement_update', 'entitlement_delete'], + /, + *, + check: Optional[Callable[[Entitlement], bool]], + timeout: Optional[float] = None, + ) -> Entitlement: + ... + # Guilds @overload @@ -1755,6 +1794,18 @@ class Client: ) -> Coroutine[Any, Any, Tuple[StageInstance, StageInstance]]: ... + # Subscriptions + @overload + async def wait_for( + self, + event: Literal['subscription_create', 'subscription_update', 'subscription_delete'], + /, + *, + check: Optional[Callable[[Subscription], bool]], + timeout: Optional[float] = None, + ) -> Subscription: + ... + # Threads @overload async def wait_for( @@ -2305,6 +2356,29 @@ class Client: data = await self.http.get_guild(guild_id, with_counts=with_counts) return Guild(data=data, state=self._connection) + async def fetch_guild_preview(self, guild_id: int) -> GuildPreview: + """|coro| + + Retrieves a preview of a :class:`.Guild` from an ID. If the guild is discoverable, + you don't have to be a member of it. + + .. versionadded:: 2.5 + + Raises + ------ + NotFound + The guild doesn't exist, or is not discoverable and you are not in it. + HTTPException + Getting the guild failed. + + Returns + -------- + :class:`.GuildPreview` + The guild preview from the ID. + """ + data = await self.http.get_guild_preview(guild_id) + return GuildPreview(data=data, state=self._connection) + async def create_guild( self, *, @@ -2750,6 +2824,7 @@ class Client: user: Optional[Snowflake] = None, guild: Optional[Snowflake] = None, exclude_ended: bool = False, + exclude_deleted: bool = True, ) -> AsyncIterator[Entitlement]: """Retrieves an :term:`asynchronous iterator` of the :class:`.Entitlement` that applications has. @@ -2791,6 +2866,10 @@ class Client: The guild to filter by. exclude_ended: :class:`bool` Whether to exclude ended entitlements. Defaults to ``False``. + exclude_deleted: :class:`bool` + Whether to exclude deleted entitlements. Defaults to ``True``. + + .. versionadded:: 2.5 Raises ------- @@ -2827,6 +2906,7 @@ class Client: user_id=user.id if user else None, guild_id=guild.id if guild else None, exclude_ended=exclude_ended, + exclude_deleted=exclude_deleted, ) if data: @@ -2875,7 +2955,7 @@ class Client: data, state, limit = await strategy(retrieve, state, limit) # Terminate loop on next iteration; there's no data left after this - if len(data) < 1000: + if len(data) < 100: limit = 0 for e in data: @@ -2964,6 +3044,26 @@ class Client: data = await self.http.get_sticker_pack(sticker_pack_id) return StickerPack(state=self._connection, data=data) + async def fetch_soundboard_default_sounds(self) -> List[SoundboardDefaultSound]: + """|coro| + + Retrieves all default soundboard sounds. + + .. versionadded:: 2.5 + + Raises + ------- + HTTPException + Retrieving the default soundboard sounds failed. + + Returns + --------- + List[:class:`.SoundboardDefaultSound`] + All default soundboard sounds. + """ + data = await self.http.get_soundboard_default_sounds() + return [SoundboardDefaultSound(state=self._connection, data=sound) for sound in data] + async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| @@ -3100,7 +3200,7 @@ class Client: Parameters ---------- name: :class:`str` - The emoji name. Must be at least 2 characters. + The emoji name. Must be between 2 and 32 characters long. image: :class:`bytes` The :term:`py:bytes-like object` representing the image data to use. Only JPG, PNG and GIF images are supported. diff --git a/discord/emoji.py b/discord/emoji.py index e011495fd..74f344acc 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -29,7 +29,7 @@ from .asset import Asset, AssetMixin from .utils import SnowflakeList, snowflake_time, MISSING from .partial_emoji import _EmojiTag, PartialEmoji from .user import User -from .app_commands.errors import MissingApplicationID +from .errors import MissingApplicationID from .object import Object # fmt: off diff --git a/discord/enums.py b/discord/enums.py index 6b47d17b8..cce6ac01d 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -74,6 +74,9 @@ __all__ = ( 'EntitlementType', 'EntitlementOwnerType', 'PollLayoutType', + 'VoiceChannelEffectAnimationType', + 'SubscriptionStatus', + 'MessageReferenceType', 'ScheduledEventRecurrenceFrequency', 'ScheduledEventRecurrenceWeekday', ) @@ -220,6 +223,12 @@ class ChannelType(Enum): return self.name +class MessageReferenceType(Enum): + default = 0 + reply = 0 + forward = 1 + + class MessageType(Enum): default = 0 recipient_add = 1 @@ -258,6 +267,7 @@ class MessageType(Enum): guild_incident_alert_mode_disabled = 37 guild_incident_report_raid = 38 guild_incident_report_false_alarm = 39 + purchase_notification = 44 class SpeakingState(Enum): @@ -379,6 +389,9 @@ class AuditLogAction(Enum): thread_update = 111 thread_delete = 112 app_command_permission_update = 121 + soundboard_sound_create = 130 + soundboard_sound_update = 131 + soundboard_sound_delete = 132 automod_rule_create = 140 automod_rule_update = 141 automod_rule_delete = 142 @@ -449,6 +462,9 @@ class AuditLogAction(Enum): AuditLogAction.automod_timeout_member: None, AuditLogAction.creator_monetization_request_created: None, AuditLogAction.creator_monetization_terms_accepted: None, + AuditLogAction.soundboard_sound_create: AuditLogActionCategory.create, + AuditLogAction.soundboard_sound_update: AuditLogActionCategory.update, + AuditLogAction.soundboard_sound_delete: AuditLogActionCategory.delete, } # fmt: on return lookup[self] @@ -837,6 +853,17 @@ class ReactionType(Enum): burst = 1 +class VoiceChannelEffectAnimationType(Enum): + premium = 0 + basic = 1 + + +class SubscriptionStatus(Enum): + active = 0 + ending = 1 + inactive = 2 + + class ScheduledEventRecurrenceFrequency(Enum): yearly = 0 monthly = 1 diff --git a/discord/errors.py b/discord/errors.py index 6035ace7c..a40842578 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -47,6 +47,12 @@ __all__ = ( 'ConnectionClosed', 'PrivilegedIntentsRequired', 'InteractionResponded', + 'MissingApplicationID', +) + +APP_ID_NOT_FOUND = ( + 'Client does not have an application_id set. Either the function was called before on_ready ' + 'was called or application_id was not passed to the Client constructor.' ) @@ -278,3 +284,22 @@ class InteractionResponded(ClientException): def __init__(self, interaction: Interaction): self.interaction: Interaction = interaction super().__init__('This interaction has already been responded to before') + + +class MissingApplicationID(ClientException): + """An exception raised when the client does not have an application ID set. + + An application ID is required for syncing application commands and various + other application tasks such as SKUs or application emojis. + + This inherits from :exc:`~discord.app_commands.AppCommandError` + and :class:`~discord.ClientException`. + + .. versionadded:: 2.0 + + .. versionchanged:: 2.5 + This is now exported to the ``discord`` namespace and now inherits from :class:`~discord.ClientException`. + """ + + def __init__(self, message: Optional[str] = None): + super().__init__(message or APP_ID_NOT_FOUND) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index ad9c286ee..5a74fa5f3 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -87,11 +87,15 @@ class DeferTyping: self.ctx: Context[BotT] = ctx self.ephemeral: bool = ephemeral + async def do_defer(self) -> None: + if self.ctx.interaction and not self.ctx.interaction.response.is_done(): + await self.ctx.interaction.response.defer(ephemeral=self.ephemeral) + def __await__(self) -> Generator[Any, None, None]: - return self.ctx.defer(ephemeral=self.ephemeral).__await__() + return self.do_defer().__await__() async def __aenter__(self) -> None: - await self.ctx.defer(ephemeral=self.ephemeral) + await self.do_defer() async def __aexit__( self, diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 6c559009d..744a00fd3 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -82,6 +82,7 @@ __all__ = ( 'GuildChannelConverter', 'GuildStickerConverter', 'ScheduledEventConverter', + 'SoundboardSoundConverter', 'clean_content', 'Greedy', 'Range', @@ -951,6 +952,44 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]): return result +class SoundboardSoundConverter(IDConverter[discord.SoundboardSound]): + """Converts to a :class:`~discord.SoundboardSound`. + + Lookups are done for the local guild if available. Otherwise, for a DM context, + lookup is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by name. + + .. versionadded:: 2.5 + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.SoundboardSound: + guild = ctx.guild + match = self._get_id_match(argument) + result = None + + if match: + # ID match + sound_id = int(match.group(1)) + if guild: + result = guild.get_soundboard_sound(sound_id) + else: + result = ctx.bot.get_soundboard_sound(sound_id) + else: + # lookup by name + if guild: + result = discord.utils.get(guild.soundboard_sounds, name=argument) + else: + result = discord.utils.get(ctx.bot.soundboard_sounds, name=argument) + if result is None: + raise SoundboardSoundNotFound(argument) + + return result + + class clean_content(Converter[str]): """Converts the argument to mention scrubbed version of said content. @@ -1263,6 +1302,7 @@ CONVERTER_MAPPING: Dict[type, Any] = { discord.GuildSticker: GuildStickerConverter, discord.ScheduledEvent: ScheduledEventConverter, discord.ForumChannel: ForumChannelConverter, + discord.SoundboardSound: SoundboardSoundConverter, } diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 9a73a2b4e..cf328d9b3 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -71,7 +71,7 @@ class BucketType(Enum): elif self is BucketType.member: return ((msg.guild and msg.guild.id), msg.author.id) elif self is BucketType.category: - return (msg.channel.category or msg.channel).id # type: ignore + return (getattr(msg.channel, 'category', None) or msg.channel).id elif self is BucketType.role: # we return the channel id of a private-channel as there are only roles in guilds # and that yields the same result as for a guild with only the @everyone role diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 0c1e0f2d0..f81d54b4d 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -75,6 +75,7 @@ __all__ = ( 'EmojiNotFound', 'GuildStickerNotFound', 'ScheduledEventNotFound', + 'SoundboardSoundNotFound', 'PartialEmojiConversionFailure', 'BadBoolArgument', 'MissingRole', @@ -564,6 +565,24 @@ class ScheduledEventNotFound(BadArgument): super().__init__(f'ScheduledEvent "{argument}" not found.') +class SoundboardSoundNotFound(BadArgument): + """Exception raised when the bot can not find the soundboard sound. + + This inherits from :exc:`BadArgument` + + .. versionadded:: 2.5 + + Attributes + ----------- + argument: :class:`str` + The sound supplied by the caller that was not found + """ + + def __init__(self, argument: str) -> None: + self.argument: str = argument + super().__init__(f'SoundboardSound "{argument}" not found.') + + class BadBoolArgument(BadArgument): """Exception raised when a boolean argument was not convertable. @@ -1061,7 +1080,7 @@ class ExtensionNotFound(ExtensionError): """ def __init__(self, name: str) -> None: - msg = f'Extension {name!r} could not be loaded.' + msg = f'Extension {name!r} could not be loaded or found.' super().__init__(msg, name=name) diff --git a/discord/ext/commands/hybrid.py b/discord/ext/commands/hybrid.py index 8c2f9a9e9..af9e63a7b 100644 --- a/discord/ext/commands/hybrid.py +++ b/discord/ext/commands/hybrid.py @@ -43,7 +43,7 @@ import inspect from discord import app_commands from discord.utils import MISSING, maybe_coroutine, async_all from .core import Command, Group -from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError +from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError, DisabledCommand from .converter import Converter, Range, Greedy, run_converters, CONVERTER_MAPPING from .parameters import Parameter from .flags import is_flag, FlagConverter @@ -234,6 +234,12 @@ def replace_parameter( descriptions[name] = flag.description if flag.name != flag.attribute: renames[name] = flag.name + if pseudo.default is not pseudo.empty: + # This ensures the default is wrapped around _CallableDefault if callable + # else leaves it as-is. + pseudo = pseudo.replace( + default=_CallableDefault(flag.default) if callable(flag.default) else flag.default + ) mapping[name] = pseudo @@ -283,7 +289,7 @@ def replace_parameters( param = param.replace(default=default) if isinstance(param.default, Parameter): - # If we're here, then then it hasn't been handled yet so it should be removed completely + # If we're here, then it hasn't been handled yet so it should be removed completely param = param.replace(default=parameter.empty) # Flags are flattened out and thus don't get their parameter in the actual mapping @@ -526,6 +532,9 @@ class HybridCommand(Command[CogT, P, T]): self.app_command.binding = value async def can_run(self, ctx: Context[BotT], /) -> bool: + if not self.enabled: + raise DisabledCommand(f'{self.name} command is disabled') + if ctx.interaction is not None and self.app_command: return await self.app_command._check_can_run(ctx.interaction) else: diff --git a/discord/ext/commands/parameters.py b/discord/ext/commands/parameters.py index 70df39534..196530d94 100644 --- a/discord/ext/commands/parameters.py +++ b/discord/ext/commands/parameters.py @@ -135,7 +135,7 @@ class Parameter(inspect.Parameter): if displayed_name is MISSING: displayed_name = self._displayed_name - return self.__class__( + ret = self.__class__( name=name, kind=kind, default=default, @@ -144,6 +144,8 @@ class Parameter(inspect.Parameter): displayed_default=displayed_default, displayed_name=displayed_name, ) + ret._fallback = self._fallback + return ret if not TYPE_CHECKING: # this is to prevent anything breaking if inspect internals change name = _gen_property('name') @@ -247,6 +249,12 @@ def parameter( .. versionadded:: 2.3 """ + if isinstance(default, Parameter): + if displayed_default is empty: + displayed_default = default._displayed_default + + default = default._default + return Parameter( name='empty', kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 81644da3a..57f9e741b 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -111,12 +111,17 @@ class SleepHandle: self.loop: asyncio.AbstractEventLoop = loop self.future: asyncio.Future[None] = loop.create_future() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = loop.call_later(relative_delta, self.future.set_result, None) + self.handle = loop.call_later(relative_delta, self._wrapped_set_result, self.future) + + @staticmethod + def _wrapped_set_result(future: asyncio.Future) -> None: + if not future.done(): + future.set_result(None) def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = discord.utils.compute_timedelta(dt) - self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, None) + self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self._wrapped_set_result, self.future) def wait(self) -> asyncio.Future[Any]: return self.future diff --git a/discord/flags.py b/discord/flags.py index 3d31e3a58..de806ba9c 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -135,7 +135,7 @@ class BaseFlags: setattr(self, key, value) @classmethod - def _from_value(cls, value): + def _from_value(cls, value: int) -> Self: self = cls.__new__(cls) self.value = value return self @@ -490,6 +490,14 @@ class MessageFlags(BaseFlags): """ return 8192 + @flag_value + def forwarded(self): + """:class:`bool`: Returns ``True`` if the message is a forwarded message. + + .. versionadded:: 2.5 + """ + return 16384 + @fill_with_flags() class PublicUserFlags(BaseFlags): @@ -871,34 +879,52 @@ class Intents(BaseFlags): @alias_flag_value def emojis(self): - """:class:`bool`: Alias of :attr:`.emojis_and_stickers`. + """:class:`bool`: Alias of :attr:`.expressions`. .. versionchanged:: 2.0 Changed to an alias. """ return 1 << 3 - @flag_value + @alias_flag_value def emojis_and_stickers(self): - """:class:`bool`: Whether guild emoji and sticker related events are enabled. + """:class:`bool`: Alias of :attr:`.expressions`. .. versionadded:: 2.0 + .. versionchanged:: 2.5 + Changed to an alias. + """ + return 1 << 3 + + @flag_value + def expressions(self): + """:class:`bool`: Whether guild emoji, sticker, and soundboard sound related events are enabled. + + .. versionadded:: 2.5 + This corresponds to the following events: - :func:`on_guild_emojis_update` - :func:`on_guild_stickers_update` + - :func:`on_soundboard_sound_create` + - :func:`on_soundboard_sound_update` + - :func:`on_soundboard_sound_delete` This also corresponds to the following attributes and classes in terms of cache: - :class:`Emoji` - :class:`GuildSticker` + - :class:`SoundboardSound` - :meth:`Client.get_emoji` - :meth:`Client.get_sticker` + - :meth:`Client.get_soundboard_sound` - :meth:`Client.emojis` - :meth:`Client.stickers` + - :meth:`Client.soundboard_sounds` - :attr:`Guild.emojis` - :attr:`Guild.stickers` + - :attr:`Guild.soundboard_sounds` """ return 1 << 3 @@ -2032,6 +2058,48 @@ class MemberFlags(BaseFlags): """:class:`bool`: Returns ``True`` if the member has started onboarding.""" return 1 << 3 + @flag_value + def guest(self): + """:class:`bool`: Returns ``True`` if the member is a guest and can only access + the voice channel they were invited to. + + .. versionadded:: 2.5 + """ + return 1 << 4 + + @flag_value + def started_home_actions(self): + """:class:`bool`: Returns ``True`` if the member has started Server Guide new member actions. + + .. versionadded:: 2.5 + """ + return 1 << 5 + + @flag_value + def completed_home_actions(self): + """:class:`bool`: Returns ``True`` if the member has completed Server Guide new member actions. + + .. versionadded:: 2.5 + """ + return 1 << 6 + + @flag_value + def automod_quarantined_username(self): + """:class:`bool`: Returns ``True`` if the member's username, nickname, or global name has been + blocked by AutoMod. + + .. versionadded:: 2.5 + """ + return 1 << 7 + + @flag_value + def dm_settings_upsell_acknowledged(self): + """:class:`bool`: Returns ``True`` if the member has dismissed the DM settings upsell. + + .. versionadded:: 2.5 + """ + return 1 << 9 + @fill_with_flags() class AttachmentFlags(BaseFlags): diff --git a/discord/gateway.py b/discord/gateway.py index b8936bf57..d15c617d1 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import asyncio @@ -32,7 +33,6 @@ import sys import time import threading import traceback -import zlib from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple @@ -295,19 +295,19 @@ class DiscordWebSocket: # fmt: off DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/') - DISPATCH = 0 - HEARTBEAT = 1 - IDENTIFY = 2 - PRESENCE = 3 - VOICE_STATE = 4 - VOICE_PING = 5 - RESUME = 6 - RECONNECT = 7 - REQUEST_MEMBERS = 8 - INVALIDATE_SESSION = 9 - HELLO = 10 - HEARTBEAT_ACK = 11 - GUILD_SYNC = 12 + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE = 3 + VOICE_STATE = 4 + VOICE_PING = 5 + RESUME = 6 + RECONNECT = 7 + REQUEST_MEMBERS = 8 + INVALIDATE_SESSION = 9 + HELLO = 10 + HEARTBEAT_ACK = 11 + GUILD_SYNC = 12 # fmt: on def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: @@ -325,8 +325,7 @@ class DiscordWebSocket: # ws related stuff self.session_id: Optional[str] = None self.sequence: Optional[int] = None - self._zlib: zlib._Decompress = zlib.decompressobj() - self._buffer: bytearray = bytearray() + self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext() self._close_code: Optional[int] = None self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() @@ -355,7 +354,7 @@ class DiscordWebSocket: sequence: Optional[int] = None, resume: bool = False, encoding: str = 'json', - zlib: bool = True, + compress: bool = True, ) -> Self: """Creates a main websocket for Discord from a :class:`Client`. @@ -366,10 +365,12 @@ class DiscordWebSocket: gateway = gateway or cls.DEFAULT_GATEWAY - if zlib: - url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream') - else: + if not compress: url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) + else: + url = gateway.with_query( + v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE + ) socket = await client.http.ws_connect(str(url)) ws = cls(socket, loop=client.loop) @@ -488,13 +489,11 @@ class DiscordWebSocket: async def received_message(self, msg: Any, /) -> None: if type(msg) is bytes: - self._buffer.extend(msg) + msg = self._decompressor.decompress(msg) - if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': + # Received a partial gateway message + if msg is None: return - msg = self._zlib.decompress(self._buffer) - msg = msg.decode('utf-8') - self._buffer = bytearray() self.log_receive(msg) msg = utils._from_json(msg) @@ -607,7 +606,10 @@ class DiscordWebSocket: def _can_handle_close(self) -> bool: code = self._close_code or self.socket.close_code - return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) + # If the socket is closed remotely with 1000 and it's not our own explicit close + # then it's an improper close that should be handled and reconnected + is_improper_close = self._close_code is None and self.socket.close_code == 1000 + return is_improper_close or code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) async def poll_event(self) -> None: """Polls for a DISPATCH event and handles the general gateway loop. diff --git a/discord/guild.py b/discord/guild.py index 98a03d5f7..fbed39771 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -94,10 +94,12 @@ from .object import OLDEST_OBJECT, Object from .welcome_screen import WelcomeScreen, WelcomeChannel from .automod import AutoModRule, AutoModTrigger, AutoModRuleAction from .partial_emoji import _EmojiTag, PartialEmoji +from .soundboard import SoundboardSound __all__ = ( 'Guild', + 'GuildPreview', 'BanEntry', ) @@ -108,6 +110,7 @@ if TYPE_CHECKING: from .types.guild import ( Ban as BanPayload, Guild as GuildPayload, + GuildPreview as GuildPreviewPayload, RolePositionUpdate as RolePositionUpdatePayload, GuildFeature, IncidentData, @@ -159,6 +162,121 @@ class _GuildLimit(NamedTuple): filesize: int +class GuildPreview(Hashable): + """Represents a preview of a Discord guild. + + .. versionadded:: 2.5 + + .. container:: operations + + .. describe:: x == y + + Checks if two guild previews are equal. + + .. describe:: x != y + + Checks if two guild previews are not equal. + + .. describe:: hash(x) + + Returns the guild's hash. + + .. describe:: str(x) + + Returns the guild's name. + + Attributes + ---------- + name: :class:`str` + The guild preview's name. + id: :class:`int` + The guild preview's ID. + features: List[:class:`str`] + A list of features the guild has. See :attr:`Guild.features` for more information. + description: Optional[:class:`str`] + The guild preview's description. + emojis: Tuple[:class:`Emoji`, ...] + All emojis that the guild owns. + stickers: Tuple[:class:`GuildSticker`, ...] + All stickers that the guild owns. + approximate_member_count: :class:`int` + The approximate number of members in the guild. + approximate_presence_count: :class:`int` + The approximate number of members currently active in in the guild. Offline members are excluded. + """ + + __slots__ = ( + '_state', + '_icon', + '_splash', + '_discovery_splash', + 'id', + 'name', + 'emojis', + 'stickers', + 'features', + 'description', + "approximate_member_count", + "approximate_presence_count", + ) + + def __init__(self, *, data: GuildPreviewPayload, state: ConnectionState) -> None: + self._state: ConnectionState = state + self.id = int(data['id']) + self.name: str = data['name'] + self._icon: Optional[str] = data.get('icon') + self._splash: Optional[str] = data.get('splash') + self._discovery_splash: Optional[str] = data.get('discovery_splash') + self.emojis: Tuple[Emoji, ...] = tuple( + map( + lambda d: Emoji(guild=state._get_or_create_unavailable_guild(self.id), state=state, data=d), + data.get('emojis', []), + ) + ) + self.stickers: Tuple[GuildSticker, ...] = tuple( + map(lambda d: GuildSticker(state=state, data=d), data.get('stickers', [])) + ) + self.features: List[GuildFeature] = data.get('features', []) + self.description: Optional[str] = data.get('description') + self.approximate_member_count: int = data.get('approximate_member_count') + self.approximate_presence_count: int = data.get('approximate_presence_count') + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return ( + f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r} ' + f'features={self.features}>' + ) + + @property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the guild's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @property + def icon(self) -> Optional[Asset]: + """Optional[:class:`Asset`]: Returns the guild's icon asset, if available.""" + if self._icon is None: + return None + return Asset._from_guild_icon(self._state, self.id, self._icon) + + @property + def splash(self) -> Optional[Asset]: + """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" + if self._splash is None: + return None + return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') + + @property + def discovery_splash(self) -> Optional[Asset]: + """Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available.""" + if self._discovery_splash is None: + return None + return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes') + + class Guild(Hashable): """Represents a Discord guild. @@ -328,6 +446,7 @@ class Guild(Hashable): '_safety_alerts_channel_id', 'max_stage_video_users', '_incidents_data', + '_soundboard_sounds', ) _PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = { @@ -345,6 +464,7 @@ class Guild(Hashable): self._threads: Dict[int, Thread] = {} self._stage_instances: Dict[int, StageInstance] = {} self._scheduled_events: Dict[int, ScheduledEvent] = {} + self._soundboard_sounds: Dict[int, SoundboardSound] = {} self._state: ConnectionState = state self._member_count: Optional[int] = None self._from_data(data) @@ -390,6 +510,12 @@ class Guild(Hashable): del self._threads[k] return to_remove + def _add_soundboard_sound(self, sound: SoundboardSound, /) -> None: + self._soundboard_sounds[sound.id] = sound + + def _remove_soundboard_sound(self, sound: SoundboardSound, /) -> None: + self._soundboard_sounds.pop(sound.id, None) + def __str__(self) -> str: return self.name or '' @@ -547,6 +673,11 @@ class Guild(Hashable): scheduled_event = ScheduledEvent(data=s, state=self._state) self._scheduled_events[scheduled_event.id] = scheduled_event + if 'soundboard_sounds' in guild: + for s in guild['soundboard_sounds']: + soundboard_sound = SoundboardSound(guild=self, data=s, state=self._state) + self._add_soundboard_sound(soundboard_sound) + @property def channels(self) -> Sequence[GuildChannel]: """Sequence[:class:`abc.GuildChannel`]: A list of channels that belongs to this guild.""" @@ -996,6 +1127,37 @@ class Guild(Hashable): """ return self._scheduled_events.get(scheduled_event_id) + @property + def soundboard_sounds(self) -> Sequence[SoundboardSound]: + """Sequence[:class:`SoundboardSound`]: Returns a sequence of the guild's soundboard sounds. + + .. versionadded:: 2.5 + """ + return utils.SequenceProxy(self._soundboard_sounds.values()) + + def get_soundboard_sound(self, sound_id: int, /) -> Optional[SoundboardSound]: + """Returns a soundboard sound with the given ID. + + .. versionadded:: 2.5 + + Parameters + ----------- + sound_id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`SoundboardSound`] + The soundboard sound or ``None`` if not found. + """ + return self._soundboard_sounds.get(sound_id) + + def _resolve_soundboard_sound(self, id: Optional[int], /) -> Optional[SoundboardSound]: + if id is None: + return + + return self._soundboard_sounds.get(id) + @property def owner(self) -> Optional[Member]: """Optional[:class:`Member`]: The member that owns the guild.""" @@ -4322,6 +4484,8 @@ class Guild(Hashable): ------- Forbidden You do not have permission to view the automod rule. + NotFound + The automod rule does not exist within this guild. Returns -------- @@ -4510,3 +4674,130 @@ class Guild(Hashable): return False return self.raid_detected_at > utils.utcnow() + + async def fetch_soundboard_sound(self, sound_id: int, /) -> SoundboardSound: + """|coro| + + Retrieves a :class:`SoundboardSound` with the specified ID. + + .. versionadded:: 2.5 + + .. note:: + + Using this, in order to receive :attr:`SoundboardSound.user`, you must have :attr:`~Permissions.create_expressions` + or :attr:`~Permissions.manage_expressions`. + + .. note:: + + This method is an API call. For general usage, consider :attr:`get_soundboard_sound` instead. + + Raises + ------- + NotFound + The sound requested could not be found. + HTTPException + Retrieving the sound failed. + + Returns + -------- + :class:`SoundboardSound` + The retrieved sound. + """ + data = await self._state.http.get_soundboard_sound(self.id, sound_id) + return SoundboardSound(guild=self, state=self._state, data=data) + + async def fetch_soundboard_sounds(self) -> List[SoundboardSound]: + """|coro| + + Retrieves a list of all soundboard sounds for the guild. + + .. versionadded:: 2.5 + + .. note:: + + Using this, in order to receive :attr:`SoundboardSound.user`, you must have :attr:`~Permissions.create_expressions` + or :attr:`~Permissions.manage_expressions`. + + .. note:: + + This method is an API call. For general usage, consider :attr:`soundboard_sounds` instead. + + Raises + ------- + HTTPException + Retrieving the sounds failed. + + Returns + -------- + List[:class:`SoundboardSound`] + The retrieved soundboard sounds. + """ + data = await self._state.http.get_soundboard_sounds(self.id) + return [SoundboardSound(guild=self, state=self._state, data=sound) for sound in data['items']] + + async def create_soundboard_sound( + self, + *, + name: str, + sound: bytes, + volume: float = 1, + emoji: Optional[EmojiInputType] = None, + reason: Optional[str] = None, + ) -> SoundboardSound: + """|coro| + + Creates a :class:`SoundboardSound` for the guild. + You must have :attr:`Permissions.create_expressions` to do this. + + .. versionadded:: 2.5 + + Parameters + ---------- + name: :class:`str` + The name of the sound. Must be between 2 and 32 characters. + sound: :class:`bytes` + The :term:`py:bytes-like object` representing the sound data. + Only MP3 and OGG sound files that don't exceed the duration of 5.2s are supported. + volume: :class:`float` + The volume of the sound. Must be between 0 and 1. Defaults to ``1``. + emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`, :class:`str`]] + The emoji of the sound. + reason: Optional[:class:`str`] + The reason for creating the sound. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to create a soundboard sound. + HTTPException + Creating the soundboard sound failed. + + Returns + ------- + :class:`SoundboardSound` + The newly created soundboard sound. + """ + payload: Dict[str, Any] = { + 'name': name, + 'sound': utils._bytes_to_base64_data(sound, audio=True), + 'volume': volume, + 'emoji_id': None, + 'emoji_name': None, + } + + if emoji is not None: + if isinstance(emoji, _EmojiTag): + partial_emoji = emoji._to_partial() + elif isinstance(emoji, str): + partial_emoji = PartialEmoji.from_str(emoji) + else: + partial_emoji = None + + if partial_emoji is not None: + if partial_emoji.id is None: + payload['emoji_name'] = partial_emoji.name + else: + payload['emoji_id'] = partial_emoji.id + + data = await self._state.http.create_soundboard_sound(self.id, reason=reason, **payload) + return SoundboardSound(guild=self, state=self._state, data=data) diff --git a/discord/http.py b/discord/http.py index 1d4299549..1809639ed 100644 --- a/discord/http.py +++ b/discord/http.py @@ -93,8 +93,11 @@ if TYPE_CHECKING: sku, poll, voice, + soundboard, + subscription, ) from .types.snowflake import Snowflake, SnowflakeList + from .types.gateway import SessionStartLimit from types import TracebackType @@ -195,6 +198,7 @@ def handle_message_parameters( if nonce is not None: payload['nonce'] = str(nonce) + payload['enforce_nonce'] = True if message_reference is not MISSING: payload['message_reference'] = message_reference @@ -306,7 +310,7 @@ class Route: self.metadata: Optional[str] = metadata url = self.BASE + self.path if parameters: - url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()}) + url = url.format_map({k: _uriquote(v, safe='') if isinstance(v, str) else v for k, v in parameters.items()}) self.url: str = url # major parameters: @@ -775,7 +779,15 @@ class HTTPClient: raise RuntimeError('Unreachable code in HTTP handling') async def get_from_cdn(self, url: str) -> bytes: - async with self.__session.get(url) as resp: + kwargs = {} + + # Proxy support + if self.proxy is not None: + kwargs['proxy'] = self.proxy + if self.proxy_auth is not None: + kwargs['proxy_auth'] = self.proxy_auth + + async with self.__session.get(url, **kwargs) as resp: if resp.status == 200: return await resp.read() elif resp.status == 404: @@ -1441,6 +1453,9 @@ class HTTPClient: params = {'with_counts': int(with_counts)} return self.request(Route('GET', '/guilds/{guild_id}', guild_id=guild_id), params=params) + def get_guild_preview(self, guild_id: Snowflake) -> Response[guild.GuildPreview]: + return self.request(Route('GET', '/guilds/{guild_id}/preview', guild_id=guild_id)) + def delete_guild(self, guild_id: Snowflake) -> Response[None]: return self.request(Route('DELETE', '/guilds/{guild_id}', guild_id=guild_id)) @@ -2447,6 +2462,7 @@ class HTTPClient: limit: Optional[int] = None, guild_id: Optional[Snowflake] = None, exclude_ended: Optional[bool] = None, + exclude_deleted: Optional[bool] = None, ) -> Response[List[sku.Entitlement]]: params: Dict[str, Any] = {} @@ -2464,6 +2480,8 @@ class HTTPClient: params['guild_id'] = guild_id if exclude_ended is not None: params['exclude_ended'] = int(exclude_ended) + if exclude_deleted is not None: + params['exclude_deleted'] = int(exclude_deleted) return self.request( Route('GET', '/applications/{application_id}/entitlements', application_id=application_id), params=params @@ -2517,6 +2535,77 @@ class HTTPClient: ), ) + # Soundboard + + def get_soundboard_default_sounds(self) -> Response[List[soundboard.SoundboardDefaultSound]]: + return self.request(Route('GET', '/soundboard-default-sounds')) + + def get_soundboard_sound(self, guild_id: Snowflake, sound_id: Snowflake) -> Response[soundboard.SoundboardSound]: + return self.request( + Route('GET', '/guilds/{guild_id}/soundboard-sounds/{sound_id}', guild_id=guild_id, sound_id=sound_id) + ) + + def get_soundboard_sounds(self, guild_id: Snowflake) -> Response[Dict[str, List[soundboard.SoundboardSound]]]: + return self.request(Route('GET', '/guilds/{guild_id}/soundboard-sounds', guild_id=guild_id)) + + def create_soundboard_sound( + self, guild_id: Snowflake, *, reason: Optional[str], **payload: Any + ) -> Response[soundboard.SoundboardSound]: + valid_keys = ( + 'name', + 'sound', + 'volume', + 'emoji_id', + 'emoji_name', + ) + + payload = {k: v for k, v in payload.items() if k in valid_keys and v is not None} + + return self.request( + Route('POST', '/guilds/{guild_id}/soundboard-sounds', guild_id=guild_id), json=payload, reason=reason + ) + + def edit_soundboard_sound( + self, guild_id: Snowflake, sound_id: Snowflake, *, reason: Optional[str], **payload: Any + ) -> Response[soundboard.SoundboardSound]: + valid_keys = ( + 'name', + 'volume', + 'emoji_id', + 'emoji_name', + ) + + payload = {k: v for k, v in payload.items() if k in valid_keys} + + return self.request( + Route( + 'PATCH', + '/guilds/{guild_id}/soundboard-sounds/{sound_id}', + guild_id=guild_id, + sound_id=sound_id, + ), + json=payload, + reason=reason, + ) + + def delete_soundboard_sound(self, guild_id: Snowflake, sound_id: Snowflake, *, reason: Optional[str]) -> Response[None]: + return self.request( + Route( + 'DELETE', + '/guilds/{guild_id}/soundboard-sounds/{sound_id}', + guild_id=guild_id, + sound_id=sound_id, + ), + reason=reason, + ) + + def send_soundboard_sound(self, channel_id: Snowflake, **payload: Any) -> Response[None]: + valid_keys = ('sound_id', 'source_guild_id') + payload = {k: v for k, v in payload.items() if k in valid_keys} + return self.request( + (Route('POST', '/channels/{channel_id}/send-soundboard-sound', channel_id=channel_id)), json=payload + ) + # Application def application_info(self) -> Response[appinfo.AppInfo]: @@ -2628,30 +2717,58 @@ class HTTPClient: ) ) - # Misc + # Subscriptions - async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: - try: - data = await self.request(Route('GET', '/gateway')) - except HTTPException as exc: - raise GatewayNotFound() from exc - if zlib: - value = '{0}?encoding={1}&v={2}&compress=zlib-stream' - else: - value = '{0}?encoding={1}&v={2}' - return value.format(data['url'], encoding, INTERNAL_API_VERSION) + def list_sku_subscriptions( + self, + sku_id: Snowflake, + before: Optional[Snowflake] = None, + after: Optional[Snowflake] = None, + limit: Optional[int] = None, + user_id: Optional[Snowflake] = None, + ) -> Response[List[subscription.Subscription]]: + params = {} + + if before is not None: + params['before'] = before + + if after is not None: + params['after'] = after - async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]: + if limit is not None: + params['limit'] = limit + + if user_id is not None: + params['user_id'] = user_id + + return self.request( + Route( + 'GET', + '/skus/{sku_id}/subscriptions', + sku_id=sku_id, + ), + params=params, + ) + + def get_sku_subscription(self, sku_id: Snowflake, subscription_id: Snowflake) -> Response[subscription.Subscription]: + return self.request( + Route( + 'GET', + '/skus/{sku_id}/subscriptions/{subscription_id}', + sku_id=sku_id, + subscription_id=subscription_id, + ) + ) + + # Misc + + async def get_bot_gateway(self) -> Tuple[int, str, SessionStartLimit]: try: data = await self.request(Route('GET', '/gateway/bot')) except HTTPException as exc: raise GatewayNotFound() from exc - if zlib: - value = '{0}?encoding={1}&v={2}&compress=zlib-stream' - else: - value = '{0}?encoding={1}&v={2}' - return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION) + return data['shards'], data['url'], data['session_start_limit'] def get_user(self, user_id: Snowflake) -> Response[user.User]: return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) diff --git a/discord/interactions.py b/discord/interactions.py index e0fb7ed86..49bfbfb07 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -192,6 +192,9 @@ class Interaction(Generic[ClientT]): self.command_failed: bool = False self._from_data(data) + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={self.id} type={self.type!r} guild_id={self.guild_id!r} user={self.user!r}>' + def _from_data(self, data: InteractionPayload): self.id: int = int(data['id']) self.type: InteractionType = try_enum(InteractionType, data['type']) diff --git a/discord/member.py b/discord/member.py index 66c771572..6118e3267 100644 --- a/discord/member.py +++ b/discord/member.py @@ -303,7 +303,7 @@ class Member(discord.abc.Messageable, _UserTag): "Nitro boost" on the guild, if available. This could be ``None``. timed_out_until: Optional[:class:`datetime.datetime`] An aware datetime object that specifies the date and time in UTC that the member's time out will expire. - This will be set to ``None`` if the user is not timed out. + This will be set to ``None`` or a time in the past if the user is not timed out. .. versionadded:: 2.0 """ @@ -1076,7 +1076,7 @@ class Member(discord.abc.Messageable, _UserTag): You must have :attr:`~Permissions.manage_roles` to use this, and the added :class:`Role`\s must appear lower in the list - of roles than the highest role of the member. + of roles than the highest role of the client. Parameters ----------- @@ -1115,7 +1115,7 @@ class Member(discord.abc.Messageable, _UserTag): You must have :attr:`~Permissions.manage_roles` to use this, and the removed :class:`Role`\s must appear lower in the list - of roles than the highest role of the member. + of roles than the highest role of the client. Parameters ----------- diff --git a/discord/message.py b/discord/message.py index 0247ee8c1..3d755e314 100644 --- a/discord/message.py +++ b/discord/message.py @@ -32,6 +32,7 @@ from os import PathLike from typing import ( Dict, TYPE_CHECKING, + Literal, Sequence, Union, List, @@ -49,7 +50,7 @@ from .asset import Asset from .reaction import Reaction from .emoji import Emoji from .partial_emoji import PartialEmoji -from .enums import InteractionType, MessageType, ChannelType, try_enum +from .enums import InteractionType, MessageReferenceType, MessageType, ChannelType, try_enum from .errors import HTTPException from .components import _component_factory from .embeds import Embed @@ -72,10 +73,14 @@ if TYPE_CHECKING: Message as MessagePayload, Attachment as AttachmentPayload, MessageReference as MessageReferencePayload, + MessageSnapshot as MessageSnapshotPayload, MessageApplication as MessageApplicationPayload, MessageActivity as MessageActivityPayload, RoleSubscriptionData as RoleSubscriptionDataPayload, MessageInteractionMetadata as MessageInteractionMetadataPayload, + CallMessage as CallMessagePayload, + PurchaseNotificationResponse as PurchaseNotificationResponsePayload, + GuildProductPurchase as GuildProductPurchasePayload, ) from .types.interactions import MessageInteraction as MessageInteractionPayload @@ -108,10 +113,14 @@ __all__ = ( 'PartialMessage', 'MessageInteraction', 'MessageReference', + 'MessageSnapshot', 'DeletedReferencedMessage', 'MessageApplication', 'RoleSubscriptionInfo', 'MessageInteractionMetadata', + 'CallMessage', + 'GuildProductPurchase', + 'PurchaseNotification', ) @@ -458,6 +467,133 @@ class DeletedReferencedMessage: return self._parent.guild_id +class MessageSnapshot: + """Represents a message snapshot attached to a forwarded message. + + .. versionadded:: 2.5 + + Attributes + ----------- + type: :class:`MessageType` + The type of the forwarded message. + content: :class:`str` + The actual contents of the forwarded message. + embeds: List[:class:`Embed`] + A list of embeds the forwarded message has. + attachments: List[:class:`Attachment`] + A list of attachments given to the forwarded message. + created_at: :class:`datetime.datetime` + The forwarded message's time of creation. + flags: :class:`MessageFlags` + Extra features of the the message snapshot. + stickers: List[:class:`StickerItem`] + A list of sticker items given to the message. + components: List[Union[:class:`ActionRow`, :class:`Button`, :class:`SelectMenu`]] + A list of components in the message. + """ + + __slots__ = ( + '_cs_raw_channel_mentions', + '_cs_cached_message', + '_cs_raw_mentions', + '_cs_raw_role_mentions', + '_edited_timestamp', + 'attachments', + 'content', + 'embeds', + 'flags', + 'created_at', + 'type', + 'stickers', + 'components', + '_state', + ) + + @classmethod + def _from_value( + cls, + state: ConnectionState, + message_snapshots: Optional[List[Dict[Literal['message'], MessageSnapshotPayload]]], + ) -> List[Self]: + if not message_snapshots: + return [] + + return [cls(state, snapshot['message']) for snapshot in message_snapshots] + + def __init__(self, state: ConnectionState, data: MessageSnapshotPayload): + self.type: MessageType = try_enum(MessageType, data['type']) + self.content: str = data['content'] + self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] + self.attachments: List[Attachment] = [Attachment(data=a, state=state) for a in data['attachments']] + self.created_at: datetime.datetime = utils.parse_time(data['timestamp']) + self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) + self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0)) + self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] + + self.components: List[MessageComponentType] = [] + for component_data in data.get('components', []): + component = _component_factory(component_data) + if component is not None: + self.components.append(component) + + self._state: ConnectionState = state + + def __repr__(self) -> str: + name = self.__class__.__name__ + return f'<{name} type={self.type!r} created_at={self.created_at!r} flags={self.flags!r}>' + + @utils.cached_slot_property('_cs_raw_mentions') + def raw_mentions(self) -> List[int]: + """List[:class:`int`]: A property that returns an array of user IDs matched with + the syntax of ``<@user_id>`` in the message content. + + This allows you to receive the user IDs of mentioned users + even in a private message context. + """ + return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)] + + @utils.cached_slot_property('_cs_raw_channel_mentions') + def raw_channel_mentions(self) -> List[int]: + """List[:class:`int`]: A property that returns an array of channel IDs matched with + the syntax of ``<#channel_id>`` in the message content. + """ + return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)] + + @utils.cached_slot_property('_cs_raw_role_mentions') + def raw_role_mentions(self) -> List[int]: + """List[:class:`int`]: A property that returns an array of role IDs matched with + the syntax of ``<@&role_id>`` in the message content. + """ + return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)] + + @utils.cached_slot_property('_cs_cached_message') + def cached_message(self) -> Optional[Message]: + """Optional[:class:`Message`]: Returns the cached message this snapshot points to, if any.""" + state = self._state + return ( + utils.find( + lambda m: ( + m.created_at == self.created_at + and m.edited_at == self.edited_at + and m.content == self.content + and m.embeds == self.embeds + and m.components == self.components + and m.stickers == self.stickers + and m.attachments == self.attachments + and m.flags == self.flags + ), + reversed(state._messages), + ) + if state._messages + else None + ) + + @property + def edited_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: An aware UTC datetime object containing the edited time of the forwarded message.""" + return self._edited_timestamp + + class MessageReference: """Represents a reference to a :class:`~discord.Message`. @@ -468,6 +604,10 @@ class MessageReference: Attributes ----------- + type: :class:`MessageReferenceType` + The type of message reference. + + .. versionadded:: 2.5 message_id: Optional[:class:`int`] The id of the message referenced. channel_id: :class:`int` @@ -475,7 +615,7 @@ class MessageReference: guild_id: Optional[:class:`int`] The guild id of the message referenced. fail_if_not_exists: :class:`bool` - Whether replying to the referenced message should raise :class:`HTTPException` + Whether the referenced message should raise :class:`HTTPException` if the message no longer exists or Discord could not fetch the message. .. versionadded:: 1.7 @@ -487,15 +627,22 @@ class MessageReference: If the message was resolved at a prior point but has since been deleted then this will be of type :class:`DeletedReferencedMessage`. - Currently, this is mainly the replied to message when a user replies to a message. - .. versionadded:: 1.6 """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state') + __slots__ = ('type', 'message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state') - def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True): + def __init__( + self, + *, + message_id: int, + channel_id: int, + guild_id: Optional[int] = None, + fail_if_not_exists: bool = True, + type: MessageReferenceType = MessageReferenceType.reply, + ): self._state: Optional[ConnectionState] = None + self.type: MessageReferenceType = type self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None self.message_id: Optional[int] = message_id self.channel_id: int = channel_id @@ -505,6 +652,7 @@ class MessageReference: @classmethod def with_state(cls, state: ConnectionState, data: MessageReferencePayload) -> Self: self = cls.__new__(cls) + self.type = try_enum(MessageReferenceType, data.get('type', 0)) self.message_id = utils._get_as_snowflake(data, 'message_id') self.channel_id = int(data['channel_id']) self.guild_id = utils._get_as_snowflake(data, 'guild_id') @@ -514,7 +662,13 @@ class MessageReference: return self @classmethod - def from_message(cls, message: PartialMessage, *, fail_if_not_exists: bool = True) -> Self: + def from_message( + cls, + message: PartialMessage, + *, + fail_if_not_exists: bool = True, + type: MessageReferenceType = MessageReferenceType.reply, + ) -> Self: """Creates a :class:`MessageReference` from an existing :class:`~discord.Message`. .. versionadded:: 1.6 @@ -524,10 +678,14 @@ class MessageReference: message: :class:`~discord.Message` The message to be converted into a reference. fail_if_not_exists: :class:`bool` - Whether replying to the referenced message should raise :class:`HTTPException` + Whether the referenced message should raise :class:`HTTPException` if the message no longer exists or Discord could not fetch the message. .. versionadded:: 1.7 + type: :class:`~discord.MessageReferenceType` + The type of message reference this is. + + .. versionadded:: 2.5 Returns ------- @@ -539,6 +697,7 @@ class MessageReference: channel_id=message.channel.id, guild_id=getattr(message.guild, 'id', None), fail_if_not_exists=fail_if_not_exists, + type=type, ) self._state = message._state return self @@ -561,7 +720,9 @@ class MessageReference: return f'' def to_dict(self) -> MessageReferencePayload: - result: Dict[str, Any] = {'message_id': self.message_id} if self.message_id is not None else {} + result: Dict[str, Any] = ( + {'type': self.type.value, 'message_id': self.message_id} if self.message_id is not None else {} + ) result['channel_id'] = self.channel_id if self.guild_id is not None: result['guild_id'] = self.guild_id @@ -667,6 +828,14 @@ class MessageInteractionMetadata(Hashable): The ID of the message that containes the interactive components, if applicable. modal_interaction: Optional[:class:`.MessageInteractionMetadata`] The metadata of the modal submit interaction that triggered this interaction, if applicable. + target_user: Optional[:class:`User`] + The user the command was run on, only applicable to user context menus. + + .. versionadded:: 2.5 + target_message_id: Optional[:class:`int`] + The ID of the message the command was run on, only applicable to message context menus. + + .. versionadded:: 2.5 """ __slots__: Tuple[str, ...] = ( @@ -676,6 +845,8 @@ class MessageInteractionMetadata(Hashable): 'original_response_message_id', 'interacted_message_id', 'modal_interaction', + 'target_user', + 'target_message_id', '_integration_owners', '_state', '_guild', @@ -687,31 +858,43 @@ class MessageInteractionMetadata(Hashable): self.id: int = int(data['id']) self.type: InteractionType = try_enum(InteractionType, data['type']) - self.user = state.create_user(data['user']) + self.user: User = state.create_user(data['user']) self._integration_owners: Dict[int, int] = { int(key): int(value) for key, value in data.get('authorizing_integration_owners', {}).items() } self.original_response_message_id: Optional[int] = None try: - self.original_response_message_id = int(data['original_response_message_id']) + self.original_response_message_id = int(data['original_response_message_id']) # type: ignore # EAFP except KeyError: pass self.interacted_message_id: Optional[int] = None try: - self.interacted_message_id = int(data['interacted_message_id']) + self.interacted_message_id = int(data['interacted_message_id']) # type: ignore # EAFP except KeyError: pass self.modal_interaction: Optional[MessageInteractionMetadata] = None try: self.modal_interaction = MessageInteractionMetadata( - state=state, guild=guild, data=data['triggering_interaction_metadata'] + state=state, guild=guild, data=data['triggering_interaction_metadata'] # type: ignore # EAFP ) except KeyError: pass + self.target_user: Optional[User] = None + try: + self.target_user = state.create_user(data['target_user']) # type: ignore # EAFP + except KeyError: + pass + + self.target_message_id: Optional[int] = None + try: + self.target_message_id = int(data['target_message_id']) # type: ignore # EAFP + except KeyError: + pass + def __repr__(self) -> str: return f'' @@ -738,6 +921,16 @@ class MessageInteractionMetadata(Hashable): return self._state._get_message(self.interacted_message_id) return None + @property + def target_message(self) -> Optional[Message]: + """Optional[:class:`~discord.Message`]: The target message, if applicable and is found in cache. + + .. versionadded:: 2.5 + """ + if self.target_message_id: + return self._state._get_message(self.target_message_id) + return None + def is_guild_integration(self) -> bool: """:class:`bool`: Returns ``True`` if the interaction is a guild integration.""" if self._guild: @@ -810,6 +1003,51 @@ class MessageApplication: return None +class CallMessage: + """Represents a message's call data in a private channel from a :class:`~discord.Message`. + + .. versionadded:: 2.5 + + Attributes + ----------- + ended_timestamp: Optional[:class:`datetime.datetime`] + The timestamp the call has ended. + participants: List[:class:`User`] + A list of users that participated in the call. + """ + + __slots__ = ('_message', 'ended_timestamp', 'participants') + + def __repr__(self) -> str: + return f'' + + def __init__(self, *, state: ConnectionState, message: Message, data: CallMessagePayload): + self._message: Message = message + self.ended_timestamp: Optional[datetime.datetime] = utils.parse_time(data.get('ended_timestamp')) + self.participants: List[User] = [] + + for user_id in data['participants']: + user_id = int(user_id) + if user_id == self._message.author.id: + self.participants.append(self._message.author) # type: ignore # can't be a Member here + else: + user = state.get_user(user_id) + if user is not None: + self.participants.append(user) + + @property + def duration(self) -> datetime.timedelta: + """:class:`datetime.timedelta`: The duration the call has lasted or is already ongoing.""" + if self.ended_timestamp is None: + return utils.utcnow() - self._message.created_at + else: + return self.ended_timestamp - self._message.created_at + + def is_ended(self) -> bool: + """:class:`bool`: Whether the call is ended or not.""" + return self.ended_timestamp is not None + + class RoleSubscriptionInfo: """Represents a message's role subscription information. @@ -843,6 +1081,59 @@ class RoleSubscriptionInfo: self.is_renewal: bool = data['is_renewal'] +class GuildProductPurchase: + """Represents a message's guild product that the user has purchased. + + .. versionadded:: 2.5 + + Attributes + ----------- + listing_id: :class:`int` + The ID of the listing that the user has purchased. + product_name: :class:`str` + The name of the product that the user has purchased. + """ + + __slots__ = ('listing_id', 'product_name') + + def __init__(self, data: GuildProductPurchasePayload) -> None: + self.listing_id: int = int(data['listing_id']) + self.product_name: str = data['product_name'] + + def __hash__(self) -> int: + return self.listing_id >> 22 + + def __eq__(self, other: object) -> bool: + return isinstance(other, GuildProductPurchase) and other.listing_id == self.listing_id + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class PurchaseNotification: + """Represents a message's purchase notification data. + + This is currently only attached to messages of type :attr:`MessageType.purchase_notification`. + + .. versionadded:: 2.5 + + Attributes + ----------- + guild_product_purchase: Optional[:class:`GuildProductPurchase`] + The guild product purchase that prompted the message. + """ + + __slots__ = ('_type', 'guild_product_purchase') + + def __init__(self, data: PurchaseNotificationResponsePayload) -> None: + self._type: int = data['type'] + + self.guild_product_purchase: Optional[GuildProductPurchase] = None + guild_product_purchase = data.get('guild_product_purchase') + if guild_product_purchase is not None: + self.guild_product_purchase = GuildProductPurchase(guild_product_purchase) + + class PartialMessage(Hashable): """Represents a partial message to aid with working messages when only a message and channel ID are present. @@ -1102,6 +1393,8 @@ class PartialMessage(Hashable): Forbidden Tried to suppress a message without permissions or edited a message's content or embed that isn't yours. + NotFound + This message does not exist. TypeError You specified both ``embed`` and ``embeds`` @@ -1593,7 +1886,12 @@ class PartialMessage(Hashable): return Message(state=self._state, channel=self.channel, data=data) - def to_reference(self, *, fail_if_not_exists: bool = True) -> MessageReference: + def to_reference( + self, + *, + fail_if_not_exists: bool = True, + type: MessageReferenceType = MessageReferenceType.reply, + ) -> MessageReference: """Creates a :class:`~discord.MessageReference` from the current message. .. versionadded:: 1.6 @@ -1601,10 +1899,14 @@ class PartialMessage(Hashable): Parameters ---------- fail_if_not_exists: :class:`bool` - Whether replying using the message reference should raise :class:`HTTPException` + Whether the referenced message should raise :class:`HTTPException` if the message no longer exists or Discord could not fetch the message. .. versionadded:: 1.7 + type: :class:`MessageReferenceType` + The type of message reference. + + .. versionadded:: 2.5 Returns --------- @@ -1612,7 +1914,44 @@ class PartialMessage(Hashable): The reference to this message. """ - return MessageReference.from_message(self, fail_if_not_exists=fail_if_not_exists) + return MessageReference.from_message(self, fail_if_not_exists=fail_if_not_exists, type=type) + + async def forward( + self, + destination: MessageableChannel, + *, + fail_if_not_exists: bool = True, + ) -> Message: + """|coro| + + Forwards this message to a channel. + + .. versionadded:: 2.5 + + Parameters + ---------- + destination: :class:`~discord.abc.Messageable` + The channel to forward this message to. + fail_if_not_exists: :class:`bool` + Whether replying using the message reference should raise :class:`HTTPException` + if the message no longer exists or Discord could not fetch the message. + + Raises + ------ + ~discord.HTTPException + Forwarding the message failed. + + Returns + ------- + :class:`.Message` + The message sent to the channel. + """ + reference = self.to_reference( + fail_if_not_exists=fail_if_not_exists, + type=MessageReferenceType.forward, + ) + ret = await destination.send(reference=reference) + return ret def to_message_reference_dict(self) -> MessageReferencePayload: data: MessageReferencePayload = { @@ -1768,6 +2107,18 @@ class Message(PartialMessage, Hashable): The poll attached to this message. .. versionadded:: 2.4 + call: Optional[:class:`CallMessage`] + The call associated with this message. + + .. versionadded:: 2.5 + purchase_notification: Optional[:class:`PurchaseNotification`] + The data of the purchase notification that prompted this :attr:`MessageType.purchase_notification` message. + + .. versionadded:: 2.5 + message_snapshots: List[:class:`MessageSnapshot`] + The message snapshots attached to this message. + + .. versionadded:: 2.5 """ __slots__ = ( @@ -1804,6 +2155,9 @@ class Message(PartialMessage, Hashable): 'position', 'interaction_metadata', 'poll', + 'call', + 'purchase_notification', + 'message_snapshots', ) if TYPE_CHECKING: @@ -1842,14 +2196,13 @@ class Message(PartialMessage, Hashable): self.position: Optional[int] = data.get('position') self.application_id: Optional[int] = utils._get_as_snowflake(data, 'application_id') self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] + self.message_snapshots: List[MessageSnapshot] = MessageSnapshot._from_value(state, data.get('message_snapshots')) - # This updates the poll so it has the counts, if the message - # was previously cached. self.poll: Optional[Poll] = None try: self.poll = Poll._from_data(data=data['poll'], message=self, state=state) except KeyError: - self.poll = state._get_poll(self.id) + pass try: # if the channel doesn't have a guild attribute, we handle that @@ -1931,7 +2284,15 @@ class Message(PartialMessage, Hashable): else: self.role_subscription = RoleSubscriptionInfo(role_subscription) - for handler in ('author', 'member', 'mentions', 'mention_roles', 'components'): + self.purchase_notification: Optional[PurchaseNotification] = None + try: + purchase_notification = data['purchase_notification'] + except KeyError: + pass + else: + self.purchase_notification = PurchaseNotification(purchase_notification) + + for handler in ('author', 'member', 'mentions', 'mention_roles', 'components', 'call'): try: getattr(self, f'_handle_{handler}')(data[handler]) except KeyError: @@ -2117,6 +2478,13 @@ class Message(PartialMessage, Hashable): def _handle_interaction_metadata(self, data: MessageInteractionMetadataPayload): self.interaction_metadata = MessageInteractionMetadata(state=self._state, guild=self.guild, data=data) + def _handle_call(self, data: CallMessagePayload): + self.call: Optional[CallMessage] + if data is not None: + self.call = CallMessage(state=self._state, message=self, data=data) + else: + self.call = None + def _rebind_cached_references( self, new_guild: Guild, @@ -2387,10 +2755,10 @@ class Message(PartialMessage, Hashable): return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!' if self.type is MessageType.role_subscription_purchase and self.role_subscription is not None: - # TODO: figure out how the message looks like for is_renewal: true total_months = self.role_subscription.total_months_subscribed months = '1 month' if total_months == 1 else f'{total_months} months' - return f'{self.author.name} joined {self.role_subscription.tier_name} and has been a subscriber of {self.guild} for {months}!' + action = 'renewed' if self.role_subscription.is_renewal else 'joined' + return f'{self.author.name} {action} **{self.role_subscription.tier_name}** and has been a subscriber of {self.guild} for {months}!' if self.type is MessageType.stage_start: return f'{self.author.name} started **{self.content}**.' @@ -2421,6 +2789,27 @@ class Message(PartialMessage, Hashable): if self.type is MessageType.guild_incident_report_false_alarm: return f'{self.author.name} reported a false alarm in {self.guild}.' + if self.type is MessageType.call: + call_ended = self.call.ended_timestamp is not None # type: ignore # call can't be None here + missed = self._state.user not in self.call.participants # type: ignore # call can't be None here + + if call_ended: + duration = utils._format_call_duration(self.call.duration) # type: ignore # call can't be None here + if missed: + return 'You missed a call from {0.author.name} that lasted {1}.'.format(self, duration) + else: + return '{0.author.name} started a call that lasted {1}.'.format(self, duration) + else: + if missed: + return '{0.author.name} started a call. \N{EM DASH} Join the call'.format(self) + else: + return '{0.author.name} started a call.'.format(self) + + if self.type is MessageType.purchase_notification and self.purchase_notification is not None: + guild_product_purchase = self.purchase_notification.guild_product_purchase + if guild_product_purchase is not None: + return f'{self.author.name} has purchased {guild_product_purchase.product_name}!' + # Fallback for unknown message types return '' @@ -2531,6 +2920,8 @@ class Message(PartialMessage, Hashable): Forbidden Tried to suppress a message without permissions or edited a message's content or embed that isn't yours. + NotFound + This message does not exist. TypeError You specified both ``embed`` and ``embeds`` diff --git a/discord/permissions.py b/discord/permissions.py index 17c7b38c9..b553e2578 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -187,7 +187,7 @@ class Permissions(BaseFlags): permissions set to ``True``. """ # Some of these are 0 because we don't want to set unnecessary bits - return cls(0b0000_0000_0000_0010_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111) + return cls(0b0000_0000_0000_0110_0111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111) @classmethod def _timeout_mask(cls) -> int: diff --git a/discord/player.py b/discord/player.py index 5b2c99dc0..bad6da88e 100644 --- a/discord/player.py +++ b/discord/player.py @@ -588,22 +588,26 @@ class FFmpegOpusAudio(FFmpegAudio): loop = asyncio.get_running_loop() try: codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) - except Exception: + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: if not fallback: _log.exception("Probe '%s' using '%s' failed", method, executable) - return # type: ignore + return None, None _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) try: codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) - except Exception: + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: _log.exception("Fallback probe using '%s' failed", executable) else: _log.debug("Fallback probe found codec=%s, bitrate=%s", codec, bitrate) else: _log.debug("Probe found codec=%s, bitrate=%s", codec, bitrate) - finally: - return codec, bitrate + + return codec, bitrate @staticmethod def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: diff --git a/discord/poll.py b/discord/poll.py index 88ed5b534..720f91245 100644 --- a/discord/poll.py +++ b/discord/poll.py @@ -359,6 +359,7 @@ class Poll: # The message's poll contains the more up to date data. self._expiry = message.poll.expires_at self._finalized = message.poll._finalized + self._answers = message.poll._answers def _update_results(self, data: PollResultPayload) -> None: self._finalized = data['is_finalized'] @@ -568,6 +569,7 @@ class Poll: if not self._message or not self._state: # Make type checker happy raise ClientException('This poll has no attached message.') - self._message = await self._message.end_poll() + message = await self._message.end_poll() + self._update(message) return self diff --git a/discord/raw_models.py b/discord/raw_models.py index 8d3ad328f..012b8f07d 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -166,20 +166,22 @@ class RawMessageUpdateEvent(_RawReprMixin): cached_message: Optional[:class:`Message`] The cached message, if found in the internal message cache. Represents the message before it is modified by the data in :attr:`RawMessageUpdateEvent.data`. + message: :class:`Message` + The updated message. + + .. versionadded:: 2.5 """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message') + __slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message', 'message') - def __init__(self, data: MessageUpdateEvent) -> None: - self.message_id: int = int(data['id']) - self.channel_id: int = int(data['channel_id']) + def __init__(self, data: MessageUpdateEvent, message: Message) -> None: + self.message_id: int = message.id + self.channel_id: int = message.channel.id self.data: MessageUpdateEvent = data + self.message: Message = message self.cached_message: Optional[Message] = None - try: - self.guild_id: Optional[int] = int(data['guild_id']) - except KeyError: - self.guild_id: Optional[int] = None + self.guild_id: Optional[int] = message.guild.id if message.guild else None class RawReactionActionEvent(_RawReprMixin): diff --git a/discord/shard.py b/discord/shard.py index 52155aa24..454fd5e28 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -47,13 +47,16 @@ from .enums import Status from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict if TYPE_CHECKING: + from typing_extensions import Unpack from .gateway import DiscordWebSocket from .activity import BaseActivity from .flags import Intents + from .types.gateway import SessionStartLimit __all__ = ( 'AutoShardedClient', 'ShardInfo', + 'SessionStartLimits', ) _log = logging.getLogger(__name__) @@ -293,6 +296,32 @@ class ShardInfo: return self._parent.ws.is_ratelimited() +class SessionStartLimits: + """A class that holds info about session start limits + + .. versionadded:: 2.5 + + Attributes + ---------- + total: :class:`int` + The total number of session starts the current user is allowed + remaining: :class:`int` + Remaining remaining number of session starts the current user is allowed + reset_after: :class:`int` + The number of milliseconds until the limit resets + max_concurrency: :class:`int` + The number of identify requests allowed per 5 seconds + """ + + __slots__ = ("total", "remaining", "reset_after", "max_concurrency") + + def __init__(self, **kwargs: Unpack[SessionStartLimit]): + self.total: int = kwargs['total'] + self.remaining: int = kwargs['remaining'] + self.reset_after: int = kwargs['reset_after'] + self.max_concurrency: int = kwargs['max_concurrency'] + + class AutoShardedClient(Client): """A client similar to :class:`Client` except it handles the complications of sharding for the user into a more manageable and transparent single @@ -415,6 +444,33 @@ class AutoShardedClient(Client): """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()} + async def fetch_session_start_limits(self) -> SessionStartLimits: + """|coro| + + Get the session start limits. + + This is not typically needed, and will be handled for you by default. + + At the point where you are launching multiple instances + with manual shard ranges and are considered required to use large bot + sharding by Discord, this function when used along IPC and a + before_identity_hook can speed up session start. + + .. versionadded:: 2.5 + + Returns + ------- + :class:`SessionStartLimits` + A class containing the session start limits + + Raises + ------ + GatewayNotFound + The gateway was unreachable + """ + _, _, limits = await self.http.get_bot_gateway() + return SessionStartLimits(**limits) + async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None: try: coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) @@ -434,7 +490,7 @@ class AutoShardedClient(Client): if self.shard_count is None: self.shard_count: int - self.shard_count, gateway_url = await self.http.get_bot_gateway() + self.shard_count, gateway_url, _session_start_limit = await self.http.get_bot_gateway() gateway = yarl.URL(gateway_url) else: gateway = DiscordWebSocket.DEFAULT_GATEWAY diff --git a/discord/sku.py b/discord/sku.py index 2af171c1d..3516370b4 100644 --- a/discord/sku.py +++ b/discord/sku.py @@ -25,16 +25,18 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import AsyncIterator, Optional, TYPE_CHECKING + +from datetime import datetime from . import utils -from .app_commands import MissingApplicationID from .enums import try_enum, SKUType, EntitlementType from .flags import SKUFlags +from .object import Object +from .subscription import Subscription if TYPE_CHECKING: - from datetime import datetime - + from .abc import SnowflakeTime, Snowflake from .guild import Guild from .state import ConnectionState from .types.sku import ( @@ -100,6 +102,149 @@ class SKU: """:class:`datetime.datetime`: Returns the sku's creation time in UTC.""" return utils.snowflake_time(self.id) + async def fetch_subscription(self, subscription_id: int, /) -> Subscription: + """|coro| + + Retrieves a :class:`.Subscription` with the specified ID. + + .. versionadded:: 2.5 + + Parameters + ----------- + subscription_id: :class:`int` + The subscription's ID to fetch from. + + Raises + ------- + NotFound + An subscription with this ID does not exist. + HTTPException + Fetching the subscription failed. + + Returns + -------- + :class:`.Subscription` + The subscription you requested. + """ + data = await self._state.http.get_sku_subscription(self.id, subscription_id) + return Subscription(data=data, state=self._state) + + async def subscriptions( + self, + *, + limit: Optional[int] = 50, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + user: Snowflake, + ) -> AsyncIterator[Subscription]: + """Retrieves an :term:`asynchronous iterator` of the :class:`.Subscription` that SKU has. + + .. versionadded:: 2.5 + + Examples + --------- + + Usage :: + + async for subscription in sku.subscriptions(limit=100, user=user): + print(subscription.user_id, subscription.current_period_end) + + Flattening into a list :: + + subscriptions = [subscription async for subscription in sku.subscriptions(limit=100, user=user)] + # subscriptions is now a list of Subscription... + + All parameters are optional. + + Parameters + ----------- + limit: Optional[:class:`int`] + The number of subscriptions to retrieve. If ``None``, it retrieves every subscription for this SKU. + Note, however, that this would make it a slow operation. Defaults to ``100``. + before: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve subscriptions before this date or entitlement. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + after: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve subscriptions after this date or entitlement. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + user: :class:`~discord.abc.Snowflake` + The user to filter by. + + Raises + ------- + HTTPException + Fetching the subscriptions failed. + TypeError + Both ``after`` and ``before`` were provided, as Discord does not + support this type of pagination. + + Yields + -------- + :class:`.Subscription` + The subscription with the SKU. + """ + + if before is not None and after is not None: + raise TypeError('subscriptions pagination does not support both before and after') + + # This endpoint paginates in ascending order. + endpoint = self._state.http.list_sku_subscriptions + + async def _before_strategy(retrieve: int, before: Optional[Snowflake], limit: Optional[int]): + before_id = before.id if before else None + data = await endpoint(self.id, before=before_id, limit=retrieve, user_id=user.id) + + if data: + if limit is not None: + limit -= len(data) + + before = Object(id=int(data[0]['id'])) + + return data, before, limit + + async def _after_strategy(retrieve: int, after: Optional[Snowflake], limit: Optional[int]): + after_id = after.id if after else None + data = await endpoint( + self.id, + after=after_id, + limit=retrieve, + user_id=user.id, + ) + + if data: + if limit is not None: + limit -= len(data) + + after = Object(id=int(data[-1]['id'])) + + return data, after, limit + + if isinstance(before, datetime): + before = Object(id=utils.time_snowflake(before, high=False)) + if isinstance(after, datetime): + after = Object(id=utils.time_snowflake(after, high=True)) + + if before: + strategy, state = _before_strategy, before + else: + strategy, state = _after_strategy, after + + while True: + retrieve = 100 if limit is None else min(limit, 100) + if retrieve < 1: + return + + data, state, limit = await strategy(retrieve, state, limit) + + # Terminate loop on next iteration; there's no data left after this + if len(data) < 100: + limit = 0 + + for e in data: + yield Subscription(data=e, state=self._state) + class Entitlement: """Represents an entitlement from user or guild which has been granted access to a premium offering. @@ -190,17 +335,12 @@ class Entitlement: Raises ------- - MissingApplicationID - The application ID could not be found. NotFound The entitlement could not be found. HTTPException Consuming the entitlement failed. """ - if self.application_id is None: - raise MissingApplicationID - await self._state.http.consume_entitlement(self.application_id, self.id) async def delete(self) -> None: @@ -210,15 +350,10 @@ class Entitlement: Raises ------- - MissingApplicationID - The application ID could not be found. NotFound The entitlement could not be found. HTTPException Deleting the entitlement failed. """ - if self.application_id is None: - raise MissingApplicationID - await self._state.http.delete_entitlement(self.application_id, self.id) diff --git a/discord/soundboard.py b/discord/soundboard.py new file mode 100644 index 000000000..3351aacb7 --- /dev/null +++ b/discord/soundboard.py @@ -0,0 +1,325 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from . import utils +from .mixins import Hashable +from .partial_emoji import PartialEmoji, _EmojiTag +from .user import User +from .utils import MISSING +from .asset import Asset, AssetMixin + +if TYPE_CHECKING: + import datetime + from typing import Dict, Any + + from .types.soundboard import ( + BaseSoundboardSound as BaseSoundboardSoundPayload, + SoundboardDefaultSound as SoundboardDefaultSoundPayload, + SoundboardSound as SoundboardSoundPayload, + ) + from .state import ConnectionState + from .guild import Guild + from .message import EmojiInputType + +__all__ = ('BaseSoundboardSound', 'SoundboardDefaultSound', 'SoundboardSound') + + +class BaseSoundboardSound(Hashable, AssetMixin): + """Represents a generic Discord soundboard sound. + + .. versionadded:: 2.5 + + .. container:: operations + + .. describe:: x == y + + Checks if two sounds are equal. + + .. describe:: x != y + + Checks if two sounds are not equal. + + .. describe:: hash(x) + + Returns the sound's hash. + + Attributes + ------------ + id: :class:`int` + The ID of the sound. + volume: :class:`float` + The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%). + """ + + __slots__ = ('_state', 'id', 'volume') + + def __init__(self, *, state: ConnectionState, data: BaseSoundboardSoundPayload): + self._state: ConnectionState = state + self.id: int = int(data['sound_id']) + self._update(data) + + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return self.id == other.id + return NotImplemented + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def _update(self, data: BaseSoundboardSoundPayload): + self.volume: float = data['volume'] + + @property + def url(self) -> str: + """:class:`str`: Returns the URL of the sound.""" + return f'{Asset.BASE}/soundboard-sounds/{self.id}' + + +class SoundboardDefaultSound(BaseSoundboardSound): + """Represents a Discord soundboard default sound. + + .. versionadded:: 2.5 + + .. container:: operations + + .. describe:: x == y + + Checks if two sounds are equal. + + .. describe:: x != y + + Checks if two sounds are not equal. + + .. describe:: hash(x) + + Returns the sound's hash. + + Attributes + ------------ + id: :class:`int` + The ID of the sound. + volume: :class:`float` + The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%). + name: :class:`str` + The name of the sound. + emoji: :class:`PartialEmoji` + The emoji of the sound. + """ + + __slots__ = ('name', 'emoji') + + def __init__(self, *, state: ConnectionState, data: SoundboardDefaultSoundPayload): + self.name: str = data['name'] + self.emoji: PartialEmoji = PartialEmoji(name=data['emoji_name']) + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + attrs = [ + ('id', self.id), + ('name', self.name), + ('volume', self.volume), + ('emoji', self.emoji), + ] + inner = ' '.join('%s=%r' % t for t in attrs) + return f"<{self.__class__.__name__} {inner}>" + + +class SoundboardSound(BaseSoundboardSound): + """Represents a Discord soundboard sound. + + .. versionadded:: 2.5 + + .. container:: operations + + .. describe:: x == y + + Checks if two sounds are equal. + + .. describe:: x != y + + Checks if two sounds are not equal. + + .. describe:: hash(x) + + Returns the sound's hash. + + Attributes + ------------ + id: :class:`int` + The ID of the sound. + volume: :class:`float` + The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%). + name: :class:`str` + The name of the sound. + emoji: Optional[:class:`PartialEmoji`] + The emoji of the sound. ``None`` if no emoji is set. + guild: :class:`Guild` + The guild in which the sound is uploaded. + available: :class:`bool` + Whether this sound is available for use. + """ + + __slots__ = ('_state', 'name', 'emoji', '_user', 'available', '_user_id', 'guild') + + def __init__(self, *, guild: Guild, state: ConnectionState, data: SoundboardSoundPayload): + super().__init__(state=state, data=data) + self.guild = guild + self._user_id = utils._get_as_snowflake(data, 'user_id') + self._user = data.get('user') + + self._update(data) + + def __repr__(self) -> str: + attrs = [ + ('id', self.id), + ('name', self.name), + ('volume', self.volume), + ('emoji', self.emoji), + ('user', self.user), + ] + inner = ' '.join('%s=%r' % t for t in attrs) + return f"<{self.__class__.__name__} {inner}>" + + def _update(self, data: SoundboardSoundPayload): + super()._update(data) + + self.name: str = data['name'] + self.emoji: Optional[PartialEmoji] = None + + emoji_id = utils._get_as_snowflake(data, 'emoji_id') + emoji_name = data['emoji_name'] + if emoji_id is not None or emoji_name is not None: + self.emoji = PartialEmoji(id=emoji_id, name=emoji_name) # type: ignore # emoji_name cannot be None here + + self.available: bool = data['available'] + + @property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the snowflake's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @property + def user(self) -> Optional[User]: + """Optional[:class:`User`]: The user who uploaded the sound.""" + if self._user is None: + if self._user_id is None: + return None + return self._state.get_user(self._user_id) + return User(state=self._state, data=self._user) + + async def edit( + self, + *, + name: str = MISSING, + volume: Optional[float] = MISSING, + emoji: Optional[EmojiInputType] = MISSING, + reason: Optional[str] = None, + ): + """|coro| + + Edits the soundboard sound. + + You must have :attr:`~Permissions.manage_expressions` to edit the sound. + If the sound was created by the client, you must have either :attr:`~Permissions.manage_expressions` + or :attr:`~Permissions.create_expressions`. + + Parameters + ---------- + name: :class:`str` + The new name of the sound. Must be between 2 and 32 characters. + volume: Optional[:class:`float`] + The new volume of the sound. Must be between 0 and 1. + emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`, :class:`str`]] + The new emoji of the sound. + reason: Optional[:class:`str`] + The reason for editing this sound. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to edit the soundboard sound. + HTTPException + Editing the soundboard sound failed. + + Returns + ------- + :class:`SoundboardSound` + The newly updated soundboard sound. + """ + + payload: Dict[str, Any] = {} + + if name is not MISSING: + payload['name'] = name + + if volume is not MISSING: + payload['volume'] = volume + + if emoji is not MISSING: + if emoji is None: + payload['emoji_id'] = None + payload['emoji_name'] = None + else: + if isinstance(emoji, _EmojiTag): + partial_emoji = emoji._to_partial() + elif isinstance(emoji, str): + partial_emoji = PartialEmoji.from_str(emoji) + else: + partial_emoji = None + + if partial_emoji is not None: + if partial_emoji.id is None: + payload['emoji_name'] = partial_emoji.name + else: + payload['emoji_id'] = partial_emoji.id + + data = await self._state.http.edit_soundboard_sound(self.guild.id, self.id, reason=reason, **payload) + return SoundboardSound(guild=self.guild, state=self._state, data=data) + + async def delete(self, *, reason: Optional[str] = None) -> None: + """|coro| + + Deletes the soundboard sound. + + You must have :attr:`~Permissions.manage_expressions` to delete the sound. + If the sound was created by the client, you must have either :attr:`~Permissions.manage_expressions` + or :attr:`~Permissions.create_expressions`. + + Parameters + ----------- + reason: Optional[:class:`str`] + The reason for deleting this sound. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to delete the soundboard sound. + HTTPException + Deleting the soundboard sound failed. + """ + await self._state.http.delete_soundboard_sound(self.guild.id, self.id, reason=reason) diff --git a/discord/state.py b/discord/state.py index db0395266..8dad83a88 100644 --- a/discord/state.py +++ b/discord/state.py @@ -78,6 +78,9 @@ from .sticker import GuildSticker from .automod import AutoModRule, AutoModAction from .audit_logs import AuditLogEntry from ._types import ClientT +from .soundboard import SoundboardSound +from .subscription import Subscription + if TYPE_CHECKING: from .abc import PrivateChannel @@ -455,6 +458,14 @@ class ConnectionState(Generic[ClientT]): def stickers(self) -> Sequence[GuildSticker]: return utils.SequenceProxy(self._stickers.values()) + @property + def soundboard_sounds(self) -> List[SoundboardSound]: + all_sounds = [] + for guild in self.guilds: + all_sounds.extend(guild.soundboard_sounds) + + return all_sounds + def get_emoji(self, emoji_id: Optional[int]) -> Optional[Emoji]: # the keys of self._emojis are ints return self._emojis.get(emoji_id) # type: ignore @@ -510,12 +521,6 @@ class ConnectionState(Generic[ClientT]): def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None - def _get_poll(self, msg_id: Optional[int]) -> Optional[Poll]: - message = self._get_message(msg_id) - if not message: - return - return message.poll - def _add_guild_from_data(self, data: GuildPayload) -> Guild: guild = Guild(data=data, state=self) self._add_guild(guild) @@ -685,17 +690,21 @@ class ConnectionState(Generic[ClientT]): self._messages.remove(msg) # type: ignore def parse_message_update(self, data: gw.MessageUpdateEvent) -> None: - raw = RawMessageUpdateEvent(data) - message = self._get_message(raw.message_id) - if message is not None: - older_message = copy.copy(message) + channel, _ = self._get_guild_channel(data) + # channel would be the correct type here + updated_message = Message(channel=channel, data=data, state=self) # type: ignore + + raw = RawMessageUpdateEvent(data=data, message=updated_message) + cached_message = self._get_message(updated_message.id) + if cached_message is not None: + older_message = copy.copy(cached_message) raw.cached_message = older_message self.dispatch('raw_message_edit', raw) - message._update(data) + cached_message._update(data) # Coerce the `after` parameter to take the new updated Member # ref: #5999 - older_message.author = message.author - self.dispatch('message_edit', older_message, message) + older_message.author = updated_message.author + self.dispatch('message_edit', older_message, updated_message) else: self.dispatch('raw_message_edit', raw) @@ -1561,6 +1570,63 @@ class ConnectionState(Generic[ClientT]): else: _log.debug('SCHEDULED_EVENT_USER_REMOVE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + def parse_guild_soundboard_sound_create(self, data: gw.GuildSoundBoardSoundCreateEvent) -> None: + guild_id = int(data['guild_id']) # type: ignore # can't be None here + guild = self._get_guild(guild_id) + if guild is not None: + sound = SoundboardSound(guild=guild, state=self, data=data) + guild._add_soundboard_sound(sound) + self.dispatch('soundboard_sound_create', sound) + else: + _log.debug('GUILD_SOUNDBOARD_SOUND_CREATE referencing unknown guild ID: %s. Discarding.', guild_id) + + def _update_and_dispatch_sound_update(self, sound: SoundboardSound, data: gw.GuildSoundBoardSoundUpdateEvent): + old_sound = copy.copy(sound) + sound._update(data) + self.dispatch('soundboard_sound_update', old_sound, sound) + + def parse_guild_soundboard_sound_update(self, data: gw.GuildSoundBoardSoundUpdateEvent) -> None: + guild_id = int(data['guild_id']) # type: ignore # can't be None here + guild = self._get_guild(guild_id) + if guild is not None: + sound_id = int(data['sound_id']) + sound = guild.get_soundboard_sound(sound_id) + if sound is not None: + self._update_and_dispatch_sound_update(sound, data) + else: + _log.warning('GUILD_SOUNDBOARD_SOUND_UPDATE referencing unknown sound ID: %s. Discarding.', sound_id) + else: + _log.debug('GUILD_SOUNDBOARD_SOUND_UPDATE referencing unknown guild ID: %s. Discarding.', guild_id) + + def parse_guild_soundboard_sound_delete(self, data: gw.GuildSoundBoardSoundDeleteEvent) -> None: + guild_id = int(data['guild_id']) + guild = self._get_guild(guild_id) + if guild is not None: + sound_id = int(data['sound_id']) + sound = guild.get_soundboard_sound(sound_id) + if sound is not None: + guild._remove_soundboard_sound(sound) + self.dispatch('soundboard_sound_delete', sound) + else: + _log.warning('GUILD_SOUNDBOARD_SOUND_DELETE referencing unknown sound ID: %s. Discarding.', sound_id) + else: + _log.debug('GUILD_SOUNDBOARD_SOUND_DELETE referencing unknown guild ID: %s. Discarding.', guild_id) + + def parse_guild_soundboard_sounds_update(self, data: gw.GuildSoundBoardSoundsUpdateEvent) -> None: + guild_id = int(data['guild_id']) + guild = self._get_guild(guild_id) + if guild is None: + _log.debug('GUILD_SOUNDBOARD_SOUNDS_UPDATE referencing unknown guild ID: %s. Discarding.', guild_id) + return + + for raw_sound in data['soundboard_sounds']: + sound_id = int(raw_sound['sound_id']) + sound = guild.get_soundboard_sound(sound_id) + if sound is not None: + self._update_and_dispatch_sound_update(sound, raw_sound) + else: + _log.warning('GUILD_SOUNDBOARD_SOUNDS_UPDATE referencing unknown sound ID: %s. Discarding.', sound_id) + def parse_application_command_permissions_update(self, data: GuildApplicationCommandPermissionsPayload): raw = RawAppCommandPermissionsUpdateEvent(data=data, state=self) self.dispatch('raw_app_command_permissions_update', raw) @@ -1591,6 +1657,14 @@ class ConnectionState(Generic[ClientT]): else: _log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) + def parse_voice_channel_effect_send(self, data: gw.VoiceChannelEffectSendEvent): + guild = self._get_guild(int(data['guild_id'])) + if guild is not None: + effect = VoiceChannelEffect(state=self, data=data, guild=guild) + self.dispatch('voice_channel_effect', effect) + else: + _log.debug('VOICE_CHANNEL_EFFECT_SEND referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + def parse_voice_server_update(self, data: gw.VoiceServerUpdateEvent) -> None: key_id = int(data['guild_id']) @@ -1669,6 +1743,18 @@ class ConnectionState(Generic[ClientT]): if poll: self.dispatch('poll_vote_remove', user, poll.get_answer(raw.answer_id)) + def parse_subscription_create(self, data: gw.SubscriptionCreateEvent) -> None: + subscription = Subscription(data=data, state=self) + self.dispatch('subscription_create', subscription) + + def parse_subscription_update(self, data: gw.SubscriptionUpdateEvent) -> None: + subscription = Subscription(data=data, state=self) + self.dispatch('subscription_update', subscription) + + def parse_subscription_delete(self, data: gw.SubscriptionDeleteEvent) -> None: + subscription = Subscription(data=data, state=self) + self.dispatch('subscription_delete', subscription) + def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, (TextChannel, Thread, VoiceChannel)): return channel.guild.get_member(user_id) @@ -1713,6 +1799,15 @@ class ConnectionState(Generic[ClientT]): def create_message(self, *, channel: MessageableChannel, data: MessagePayload) -> Message: return Message(state=self, channel=channel, data=data) + def get_soundboard_sound(self, id: Optional[int]) -> Optional[SoundboardSound]: + if id is None: + return + + for guild in self.guilds: + sound = guild._resolve_soundboard_sound(id) + if sound is not None: + return sound + class AutoShardedConnectionState(ConnectionState[ClientT]): def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/discord/subscription.py b/discord/subscription.py new file mode 100644 index 000000000..ec6d7c3e5 --- /dev/null +++ b/discord/subscription.py @@ -0,0 +1,107 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import datetime +from typing import List, Optional, TYPE_CHECKING + +from . import utils +from .mixins import Hashable +from .enums import try_enum, SubscriptionStatus + +if TYPE_CHECKING: + from .state import ConnectionState + from .types.subscription import Subscription as SubscriptionPayload + from .user import User + +__all__ = ('Subscription',) + + +class Subscription(Hashable): + """Represents a Discord subscription. + + .. versionadded:: 2.5 + + Attributes + ----------- + id: :class:`int` + The subscription's ID. + user_id: :class:`int` + The ID of the user that is subscribed. + sku_ids: List[:class:`int`] + The IDs of the SKUs that the user subscribed to. + entitlement_ids: List[:class:`int`] + The IDs of the entitlements granted for this subscription. + current_period_start: :class:`datetime.datetime` + When the current billing period started. + current_period_end: :class:`datetime.datetime` + When the current billing period ends. + status: :class:`SubscriptionStatus` + The status of the subscription. + canceled_at: Optional[:class:`datetime.datetime`] + When the subscription was canceled. + This is only available for subscriptions with a :attr:`status` of :attr:`SubscriptionStatus.inactive`. + renewal_sku_ids: List[:class:`int`] + The IDs of the SKUs that the user is going to be subscribed to when renewing. + """ + + __slots__ = ( + '_state', + 'id', + 'user_id', + 'sku_ids', + 'entitlement_ids', + 'current_period_start', + 'current_period_end', + 'status', + 'canceled_at', + 'renewal_sku_ids', + ) + + def __init__(self, *, state: ConnectionState, data: SubscriptionPayload): + self._state = state + + self.id: int = int(data['id']) + self.user_id: int = int(data['user_id']) + self.sku_ids: List[int] = list(map(int, data['sku_ids'])) + self.entitlement_ids: List[int] = list(map(int, data['entitlement_ids'])) + self.current_period_start: datetime.datetime = utils.parse_time(data['current_period_start']) + self.current_period_end: datetime.datetime = utils.parse_time(data['current_period_end']) + self.status: SubscriptionStatus = try_enum(SubscriptionStatus, data['status']) + self.canceled_at: Optional[datetime.datetime] = utils.parse_time(data['canceled_at']) + self.renewal_sku_ids: List[int] = list(map(int, data['renewal_sku_ids'] or [])) + + def __repr__(self) -> str: + return f'' + + @property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the subscription's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @property + def user(self) -> Optional[User]: + """Optional[:class:`User`]: The user that is subscribed.""" + return self._state.get_user(self.user_id) diff --git a/discord/types/audit_log.py b/discord/types/audit_log.py index cd949709a..2c37542fd 100644 --- a/discord/types/audit_log.py +++ b/discord/types/audit_log.py @@ -88,6 +88,9 @@ AuditLogEvent = Literal[ 111, 112, 121, + 130, + 131, + 132, 140, 141, 142, @@ -112,6 +115,7 @@ class _AuditLogChange_Str(TypedDict): 'permissions', 'tags', 'unicode_emoji', + 'emoji_name', ] new_value: str old_value: str @@ -136,6 +140,8 @@ class _AuditLogChange_Snowflake(TypedDict): 'channel_id', 'inviter_id', 'guild_id', + 'user_id', + 'sound_id', ] new_value: Snowflake old_value: Snowflake @@ -183,6 +189,12 @@ class _AuditLogChange_Int(TypedDict): old_value: int +class _AuditLogChange_Float(TypedDict): + key: Literal['volume'] + new_value: float + old_value: float + + class _AuditLogChange_ListRole(TypedDict): key: Literal['$add', '$remove'] new_value: List[Role] @@ -290,6 +302,7 @@ AuditLogChange = Union[ _AuditLogChange_AssetHash, _AuditLogChange_Snowflake, _AuditLogChange_Int, + _AuditLogChange_Float, _AuditLogChange_Bool, _AuditLogChange_ListRole, _AuditLogChange_MFALevel, diff --git a/discord/types/channel.py b/discord/types/channel.py index d5d82b5c6..4b593e554 100644 --- a/discord/types/channel.py +++ b/discord/types/channel.py @@ -28,6 +28,7 @@ from typing_extensions import NotRequired from .user import PartialUser from .snowflake import Snowflake from .threads import ThreadMetadata, ThreadMember, ThreadArchiveDuration, ThreadType +from .emoji import PartialEmoji OverwriteType = Literal[0, 1] @@ -89,6 +90,20 @@ class VoiceChannel(_BaseTextChannel): video_quality_mode: NotRequired[VideoQualityMode] +VoiceChannelEffectAnimationType = Literal[0, 1] + + +class VoiceChannelEffect(TypedDict): + guild_id: Snowflake + channel_id: Snowflake + user_id: Snowflake + emoji: NotRequired[Optional[PartialEmoji]] + animation_type: NotRequired[VoiceChannelEffectAnimationType] + animation_id: NotRequired[int] + sound_id: NotRequired[Union[int, str]] + sound_volume: NotRequired[float] + + class CategoryChannel(_BaseGuildChannel): type: Literal[4] diff --git a/discord/types/emoji.py b/discord/types/emoji.py index d54690c14..85e709757 100644 --- a/discord/types/emoji.py +++ b/discord/types/emoji.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from typing import Optional, TypedDict +from typing_extensions import NotRequired from .snowflake import Snowflake, SnowflakeList from .user import User @@ -30,6 +31,7 @@ from .user import User class PartialEmoji(TypedDict): id: Optional[Snowflake] name: Optional[str] + animated: NotRequired[bool] class Emoji(PartialEmoji, total=False): diff --git a/discord/types/gateway.py b/discord/types/gateway.py index ff43a5f25..7dca5badc 100644 --- a/discord/types/gateway.py +++ b/discord/types/gateway.py @@ -31,7 +31,7 @@ from .sku import Entitlement from .voice import GuildVoiceState from .integration import BaseIntegration, IntegrationApplication from .role import Role -from .channel import ChannelType, StageInstance +from .channel import ChannelType, StageInstance, VoiceChannelEffect from .interactions import Interaction from .invite import InviteTargetType from .emoji import Emoji, PartialEmoji @@ -45,6 +45,8 @@ from .user import User, AvatarDecorationData from .threads import Thread, ThreadMember from .scheduled_event import GuildScheduledEvent from .audit_log import AuditLogEntry +from .soundboard import SoundboardSound +from .subscription import Subscription class SessionStartLimit(TypedDict): @@ -90,8 +92,7 @@ class MessageDeleteBulkEvent(TypedDict): guild_id: NotRequired[Snowflake] -class MessageUpdateEvent(Message): - channel_id: Snowflake +MessageUpdateEvent = MessageCreateEvent class MessageReactionAddEvent(TypedDict): @@ -319,6 +320,19 @@ class _GuildScheduledEventUsersEvent(TypedDict): GuildScheduledEventUserAdd = GuildScheduledEventUserRemove = _GuildScheduledEventUsersEvent VoiceStateUpdateEvent = GuildVoiceState +VoiceChannelEffectSendEvent = VoiceChannelEffect + +GuildSoundBoardSoundCreateEvent = GuildSoundBoardSoundUpdateEvent = SoundboardSound + + +class GuildSoundBoardSoundsUpdateEvent(TypedDict): + guild_id: Snowflake + soundboard_sounds: List[SoundboardSound] + + +class GuildSoundBoardSoundDeleteEvent(TypedDict): + sound_id: Snowflake + guild_id: Snowflake class VoiceServerUpdateEvent(TypedDict): @@ -362,3 +376,6 @@ class PollVoteActionEvent(TypedDict): message_id: Snowflake guild_id: NotRequired[Snowflake] answer_id: int + + +SubscriptionCreateEvent = SubscriptionUpdateEvent = SubscriptionDeleteEvent = Subscription diff --git a/discord/types/guild.py b/discord/types/guild.py index ba43fbf96..e0a1f3e54 100644 --- a/discord/types/guild.py +++ b/discord/types/guild.py @@ -37,6 +37,7 @@ from .member import Member from .emoji import Emoji from .user import User from .threads import Thread +from .soundboard import SoundboardSound class Ban(TypedDict): @@ -90,6 +91,8 @@ GuildFeature = Literal[ 'VIP_REGIONS', 'WELCOME_SCREEN_ENABLED', 'RAID_ALERTS_DISABLED', + 'SOUNDBOARD', + 'MORE_SOUNDBOARD', ] @@ -154,6 +157,7 @@ class Guild(_BaseGuildPreview): max_members: NotRequired[int] premium_subscription_count: NotRequired[int] max_video_channel_users: NotRequired[int] + soundboard_sounds: NotRequired[List[SoundboardSound]] class InviteGuild(Guild, total=False): diff --git a/discord/types/interactions.py b/discord/types/interactions.py index 7aac5df7d..a72a5b2ce 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -255,11 +255,49 @@ class MessageInteraction(TypedDict): member: NotRequired[Member] -class MessageInteractionMetadata(TypedDict): +class _MessageInteractionMetadata(TypedDict): id: Snowflake - type: InteractionType user: User authorizing_integration_owners: Dict[Literal['0', '1'], Snowflake] original_response_message_id: NotRequired[Snowflake] - interacted_message_id: NotRequired[Snowflake] - triggering_interaction_metadata: NotRequired[MessageInteractionMetadata] + + +class _ApplicationCommandMessageInteractionMetadata(_MessageInteractionMetadata): + type: Literal[2] + # command_type: Literal[1, 2, 3, 4] + + +class UserApplicationCommandMessageInteractionMetadata(_ApplicationCommandMessageInteractionMetadata): + # command_type: Literal[2] + target_user: User + + +class MessageApplicationCommandMessageInteractionMetadata(_ApplicationCommandMessageInteractionMetadata): + # command_type: Literal[3] + target_message_id: Snowflake + + +ApplicationCommandMessageInteractionMetadata = Union[ + _ApplicationCommandMessageInteractionMetadata, + UserApplicationCommandMessageInteractionMetadata, + MessageApplicationCommandMessageInteractionMetadata, +] + + +class MessageComponentMessageInteractionMetadata(_MessageInteractionMetadata): + type: Literal[3] + interacted_message_id: Snowflake + + +class ModalSubmitMessageInteractionMetadata(_MessageInteractionMetadata): + type: Literal[5] + triggering_interaction_metadata: Union[ + ApplicationCommandMessageInteractionMetadata, MessageComponentMessageInteractionMetadata + ] + + +MessageInteractionMetadata = Union[ + ApplicationCommandMessageInteractionMetadata, + MessageComponentMessageInteractionMetadata, + ModalSubmitMessageInteractionMetadata, +] diff --git a/discord/types/message.py b/discord/types/message.py index bdb3f10ef..1ec86681b 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -102,7 +102,11 @@ class MessageApplication(TypedDict): cover_image: NotRequired[str] +MessageReferenceType = Literal[0, 1] + + class MessageReference(TypedDict, total=False): + type: MessageReferenceType message_id: Snowflake channel_id: Required[Snowflake] guild_id: Snowflake @@ -116,6 +120,24 @@ class RoleSubscriptionData(TypedDict): is_renewal: bool +PurchaseNotificationResponseType = Literal[0] + + +class GuildProductPurchase(TypedDict): + listing_id: Snowflake + product_name: str + + +class PurchaseNotificationResponse(TypedDict): + type: PurchaseNotificationResponseType + guild_product_purchase: Optional[GuildProductPurchase] + + +class CallMessage(TypedDict): + participants: SnowflakeList + ended_timestamp: NotRequired[Optional[str]] + + MessageType = Literal[ 0, 1, @@ -151,9 +173,24 @@ MessageType = Literal[ 37, 38, 39, + 44, ] +class MessageSnapshot(TypedDict): + type: MessageType + content: str + embeds: List[Embed] + attachments: List[Attachment] + timestamp: str + edited_timestamp: Optional[str] + flags: NotRequired[int] + mentions: List[UserWithMember] + mention_roles: SnowflakeList + sticker_items: NotRequired[List[StickerItem]] + components: NotRequired[List[Component]] + + class Message(PartialMessage): id: Snowflake author: User @@ -187,6 +224,8 @@ class Message(PartialMessage): position: NotRequired[int] role_subscription_data: NotRequired[RoleSubscriptionData] thread: NotRequired[Thread] + call: NotRequired[CallMessage] + purchase_notification: NotRequired[PurchaseNotificationResponse] AllowedMentionType = Literal['roles', 'users', 'everyone'] diff --git a/discord/types/soundboard.py b/discord/types/soundboard.py new file mode 100644 index 000000000..4910df808 --- /dev/null +++ b/discord/types/soundboard.py @@ -0,0 +1,49 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from typing import TypedDict, Optional, Union +from typing_extensions import NotRequired + +from .snowflake import Snowflake +from .user import User + + +class BaseSoundboardSound(TypedDict): + sound_id: Union[Snowflake, str] # basic string number when it's a default sound + volume: float + + +class SoundboardSound(BaseSoundboardSound): + name: str + emoji_name: Optional[str] + emoji_id: Optional[Snowflake] + user_id: NotRequired[Snowflake] + available: bool + guild_id: NotRequired[Snowflake] + user: NotRequired[User] + + +class SoundboardDefaultSound(BaseSoundboardSound): + name: str + emoji_name: str diff --git a/discord/types/subscription.py b/discord/types/subscription.py new file mode 100644 index 000000000..8d4c02070 --- /dev/null +++ b/discord/types/subscription.py @@ -0,0 +1,43 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import List, Literal, Optional, TypedDict + +from .snowflake import Snowflake + +SubscriptionStatus = Literal[0, 1, 2] + + +class Subscription(TypedDict): + id: Snowflake + user_id: Snowflake + sku_ids: List[Snowflake] + entitlement_ids: List[Snowflake] + current_period_start: str + current_period_end: str + status: SubscriptionStatus + canceled_at: Optional[str] + renewal_sku_ids: Optional[List[Snowflake]] diff --git a/discord/types/voice.py b/discord/types/voice.py index 8f4e2e03e..7e856ecdd 100644 --- a/discord/types/voice.py +++ b/discord/types/voice.py @@ -29,7 +29,12 @@ from .snowflake import Snowflake from .member import MemberWithUser -SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305'] +SupportedModes = Literal[ + 'aead_xchacha20_poly1305_rtpsize', + 'xsalsa20_poly1305_lite', + 'xsalsa20_poly1305_suffix', + 'xsalsa20_poly1305', +] class _VoiceState(TypedDict): diff --git a/discord/utils.py b/discord/utils.py index 89cc8bdeb..9b6bd59a2 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import array @@ -41,7 +42,6 @@ from typing import ( Iterator, List, Literal, - Mapping, NamedTuple, Optional, Protocol, @@ -71,6 +71,7 @@ import types import typing import warnings import logging +import zlib import yarl @@ -81,6 +82,12 @@ except ModuleNotFoundError: else: HAS_ORJSON = True +try: + import zstandard # type: ignore +except ImportError: + _HAS_ZSTD = False +else: + _HAS_ZSTD = True __all__ = ( 'oauth_url', @@ -148,8 +155,11 @@ if TYPE_CHECKING: from .invite import Invite from .template import Template - class _RequestLike(Protocol): - headers: Mapping[str, Any] + class _DecompressionContext(Protocol): + COMPRESSION_TYPE: str + + def decompress(self, data: bytes, /) -> str | None: + ... P = ParamSpec('P') @@ -317,7 +327,7 @@ def oauth_url( permissions: Permissions = MISSING, guild: Snowflake = MISSING, redirect_uri: str = MISSING, - scopes: Iterable[str] = MISSING, + scopes: Optional[Iterable[str]] = MISSING, disable_guild_select: bool = False, state: str = MISSING, ) -> str: @@ -359,7 +369,8 @@ def oauth_url( The OAuth2 URL for inviting the bot into guilds. """ url = f'https://discord.com/oauth2/authorize?client_id={client_id}' - url += '&scope=' + '+'.join(scopes or ('bot', 'applications.commands')) + if scopes is not None: + url += '&scope=' + '+'.join(scopes or ('bot', 'applications.commands')) if permissions is not MISSING: url += f'&permissions={permissions.value}' if guild is not MISSING: @@ -623,9 +634,19 @@ def _get_mime_type_for_image(data: bytes): raise ValueError('Unsupported image type given') -def _bytes_to_base64_data(data: bytes) -> str: +def _get_mime_type_for_audio(data: bytes): + if data.startswith(b'\x49\x44\x33') or data.startswith(b'\xff\xfb'): + return 'audio/mpeg' + else: + raise ValueError('Unsupported audio type given') + + +def _bytes_to_base64_data(data: bytes, *, audio: bool = False) -> str: fmt = 'data:{mime};base64,{data}' - mime = _get_mime_type_for_image(data) + if audio: + mime = _get_mime_type_for_audio(data) + else: + mime = _get_mime_type_for_image(data) b64 = b64encode(data).decode('ascii') return fmt.format(mime=mime, data=b64) @@ -848,6 +869,12 @@ def resolve_invite(invite: Union[Invite, str]) -> ResolvedInvite: invite: Union[:class:`~discord.Invite`, :class:`str`] The invite. + Raises + ------- + ValueError + The invite is not a valid Discord invite, e.g. is not a URL + or does not contain alphanumeric characters. + Returns -------- :class:`.ResolvedInvite` @@ -867,7 +894,12 @@ def resolve_invite(invite: Union[Invite, str]) -> ResolvedInvite: event_id = url.query.get('event') return ResolvedInvite(code, int(event_id) if event_id else None) - return ResolvedInvite(invite, None) + + allowed_characters = r'[a-zA-Z0-9\-_]+' + if not re.fullmatch(allowed_characters, invite): + raise ValueError('Invite contains characters that are not allowed') + + return ResolvedInvite(invite, None) def resolve_template(code: Union[Template, str]) -> str: @@ -1406,3 +1438,97 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o return f'{seq[0]} {final} {seq[1]}' return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' + + +if _HAS_ZSTD: + + class _ZstdDecompressionContext: + __slots__ = ('context',) + + COMPRESSION_TYPE: str = 'zstd-stream' + + def __init__(self) -> None: + decompressor = zstandard.ZstdDecompressor() + self.context = decompressor.decompressobj() + + def decompress(self, data: bytes, /) -> str | None: + # Each WS message is a complete gateway message + return self.context.decompress(data).decode('utf-8') + + _ActiveDecompressionContext: Type[_DecompressionContext] = _ZstdDecompressionContext +else: + + class _ZlibDecompressionContext: + __slots__ = ('context', 'buffer') + + COMPRESSION_TYPE: str = 'zlib-stream' + + def __init__(self) -> None: + self.buffer: bytearray = bytearray() + self.context = zlib.decompressobj() + + def decompress(self, data: bytes, /) -> str | None: + self.buffer.extend(data) + + # Check whether ending is Z_SYNC_FLUSH + if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff': + return + + msg = self.context.decompress(self.buffer) + self.buffer = bytearray() + + return msg.decode('utf-8') + + _ActiveDecompressionContext: Type[_DecompressionContext] = _ZlibDecompressionContext + + +def _format_call_duration(duration: datetime.timedelta) -> str: + seconds = duration.total_seconds() + + minutes_s = 60 + hours_s = minutes_s * 60 + days_s = hours_s * 24 + # Discord uses approx. 1/12 of 365.25 days (avg. days per year) + months_s = days_s * 30.4375 + years_s = months_s * 12 + + threshold_s = 45 + threshold_m = 45 + threshold_h = 21.5 + threshold_d = 25.5 + threshold_M = 10.5 + + if seconds < threshold_s: + formatted = "a few seconds" + elif seconds < (threshold_m * minutes_s): + minutes = round(seconds / minutes_s) + if minutes == 1: + formatted = "a minute" + else: + formatted = f"{minutes} minutes" + elif seconds < (threshold_h * hours_s): + hours = round(seconds / hours_s) + if hours == 1: + formatted = "an hour" + else: + formatted = f"{hours} hours" + elif seconds < (threshold_d * days_s): + days = round(seconds / days_s) + if days == 1: + formatted = "a day" + else: + formatted = f"{days} days" + elif seconds < (threshold_M * months_s): + months = round(seconds / months_s) + if months == 1: + formatted = "a month" + else: + formatted = f"{months} months" + else: + years = round(seconds / years_s) + if years == 1: + formatted = "a year" + else: + formatted = f"{years} years" + + return formatted diff --git a/discord/voice_client.py b/discord/voice_client.py index 3e1c6a5ff..795434e1e 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -230,12 +230,13 @@ class VoiceClient(VoiceProtocol): self.timestamp: int = 0 self._player: Optional[AudioPlayer] = None self.encoder: Encoder = MISSING - self._lite_nonce: int = 0 + self._incr_nonce: int = 0 self._connection: VoiceConnectionState = self.create_connection_state() warn_nacl: bool = not has_nacl supported_modes: Tuple[SupportedModes, ...] = ( + 'aead_xchacha20_poly1305_rtpsize', 'xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305', @@ -380,7 +381,21 @@ class VoiceClient(VoiceProtocol): encrypt_packet = getattr(self, '_encrypt_' + self.mode) return encrypt_packet(header, data) + def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes: + # Esentially the same as _lite + # Uses an incrementing 32-bit integer which is appended to the payload + # The only other difference is we require AEAD with Additional Authenticated Data (the header) + box = nacl.secret.Aead(bytes(self.secret_key)) + nonce = bytearray(24) + + nonce[:4] = struct.pack('>I', self._incr_nonce) + self.checked_add('_incr_nonce', 1, 4294967295) + + return header + box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext + nonce[:4] + def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes: + # Deprecated. Removal: 18th Nov 2024. See: + # https://discord.com/developers/docs/topics/voice-connections#transport-encryption-modes box = nacl.secret.SecretBox(bytes(self.secret_key)) nonce = bytearray(24) nonce[:12] = header @@ -388,17 +403,21 @@ class VoiceClient(VoiceProtocol): return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes: + # Deprecated. Removal: 18th Nov 2024. See: + # https://discord.com/developers/docs/topics/voice-connections#transport-encryption-modes box = nacl.secret.SecretBox(bytes(self.secret_key)) nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE) return header + box.encrypt(bytes(data), nonce).ciphertext + nonce def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes: + # Deprecated. Removal: 18th Nov 2024. See: + # https://discord.com/developers/docs/topics/voice-connections#transport-encryption-modes box = nacl.secret.SecretBox(bytes(self.secret_key)) nonce = bytearray(24) - nonce[:4] = struct.pack('>I', self._lite_nonce) - self.checked_add('_lite_nonce', 1, 4294967295) + nonce[:4] = struct.pack('>I', self._incr_nonce) + self.checked_add('_incr_nonce', 1, 4294967295) return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 91b021a5e..2d9856ae3 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -721,11 +721,6 @@ class _WebhookState: return self._parent._get_guild(guild_id) return None - def _get_poll(self, msg_id: Optional[int]) -> Optional[Poll]: - if self._parent is not None: - return self._parent._get_poll(msg_id) - return None - def store_user(self, data: Union[UserPayload, PartialUserPayload], *, cache: bool = True) -> BaseUser: if self._parent is not None: return self._parent.store_user(data, cache=cache) diff --git a/docs/api.rst b/docs/api.rst index bfc08836b..fd636c86f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1298,6 +1298,35 @@ Scheduled Events :type user: :class:`User` +Soundboard +~~~~~~~~~~~ + +.. function:: on_soundboard_sound_create(sound) + on_soundboard_sound_delete(sound) + + Called when a :class:`SoundboardSound` is created or deleted. + + .. versionadded:: 2.5 + + :param sound: The soundboard sound that was created or deleted. + :type sound: :class:`SoundboardSound` + +.. function:: on_soundboard_sound_update(before, after) + + Called when a :class:`SoundboardSound` is updated. + + The following examples illustrate when this event is called: + + - The name is changed. + - The emoji is changed. + - The volume is changed. + + .. versionadded:: 2.5 + + :param sound: The soundboard sound that was updated. + :type sound: :class:`SoundboardSound` + + Stages ~~~~~~~ @@ -1327,6 +1356,37 @@ Stages :param after: The stage instance after the update. :type after: :class:`StageInstance` + +Subscriptions +~~~~~~~~~~~~~ + +.. function:: on_subscription_create(subscription) + + Called when a subscription is created. + + .. versionadded:: 2.5 + + :param subscription: The subscription that was created. + :type subscription: :class:`Subscription` + +.. function:: on_subscription_update(subscription) + + Called when a subscription is updated. + + .. versionadded:: 2.5 + + :param subscription: The subscription that was updated. + :type subscription: :class:`Subscription` + +.. function:: on_subscription_delete(subscription) + + Called when a subscription is deleted. + + .. versionadded:: 2.5 + + :param subscription: The subscription that was deleted. + :type subscription: :class:`Subscription` + Threads ~~~~~~~~ @@ -1483,6 +1543,17 @@ Voice :param after: The voice state after the changes. :type after: :class:`VoiceState` +.. function:: on_voice_channel_effect(effect) + + Called when a :class:`Member` sends a :class:`VoiceChannelEffect` in a voice channel the bot is in. + + This requires :attr:`Intents.voice_states` to be enabled. + + .. versionadded:: 2.5 + + :param effect: The effect that is sent. + :type effect: :class:`VoiceChannelEffect` + .. _discord-api-utils: Utility Functions @@ -1810,6 +1881,12 @@ of :class:`enum.Enum`. .. versionadded:: 2.4 + .. attribute:: purchase_notification + + The system message sent when a purchase is made in the guild. + + .. versionadded:: 2.5 + .. class:: UserFlags Represents Discord User flags. @@ -2945,6 +3022,42 @@ of :class:`enum.Enum`. .. versionadded:: 2.4 + .. attribute:: soundboard_sound_create + + A soundboard sound was created. + + Possible attributes for :class:`AuditLogDiff`: + + - :attr:`~AuditLogDiff.name` + - :attr:`~AuditLogDiff.emoji` + - :attr:`~AuditLogDiff.volume` + + .. versionadded:: 2.5 + + .. attribute:: soundboard_sound_update + + A soundboard sound was updated. + + Possible attributes for :class:`AuditLogDiff`: + + - :attr:`~AuditLogDiff.name` + - :attr:`~AuditLogDiff.emoji` + - :attr:`~AuditLogDiff.volume` + + .. versionadded:: 2.5 + + .. attribute:: soundboard_sound_delete + + A soundboard sound was deleted. + + Possible attributes for :class:`AuditLogDiff`: + + - :attr:`~AuditLogDiff.name` + - :attr:`~AuditLogDiff.emoji` + - :attr:`~AuditLogDiff.volume` + + .. versionadded:: 2.5 + .. class:: AuditLogActionCategory Represents the category that the :class:`AuditLogAction` belongs to. @@ -3722,6 +3835,58 @@ of :class:`enum.Enum`. The ``6`` day of the week. +.. class:: VoiceChannelEffectAnimationType + + Represents the animation type of a voice channel effect. + + .. versionadded:: 2.5 + + .. attribute:: premium + + A fun animation, sent by a Nitro subscriber. + + .. attribute:: basic + + The standard animation. + + +.. class:: SubscriptionStatus + + Represents the status of an subscription. + + .. versionadded:: 2.5 + + .. attribute:: active + + The subscription is active. + + .. attribute:: ending + + The subscription is active but will not renew. + + .. attribute:: inactive + + The subscription is inactive and not being charged. + + +.. class:: MessageReferenceType + + Represents the type of a message reference. + + .. versionadded:: 2.5 + + .. attribute:: reply + + A message reply. + + .. attribute:: forward + + A forwarded message. + + .. attribute:: default + + An alias for :attr:`.reply`. + .. _discord-api-audit-logs: Audit Log Data @@ -4187,11 +4352,12 @@ AuditLogDiff .. attribute:: emoji - The name of the emoji that represents a sticker being changed. + The emoji which represents one of the following: - See also :attr:`GuildSticker.emoji`. + * :attr:`GuildSticker.emoji` + * :attr:`SoundboardSound.emoji` - :type: :class:`str` + :type: Union[:class:`str`, :class:`PartialEmoji`] .. attribute:: unicode_emoji @@ -4212,9 +4378,10 @@ AuditLogDiff .. attribute:: available - The availability of a sticker being changed. + The availability of one of the following being changed: - See also :attr:`GuildSticker.available` + * :attr:`GuildSticker.available` + * :attr:`SoundboardSound.available` :type: :class:`bool` @@ -4437,6 +4604,22 @@ AuditLogDiff :type: Optional[:class:`PartialEmoji`] + .. attribute:: user + + The user that represents the uploader of a soundboard sound. + + See also :attr:`SoundboardSound.user` + + :type: Union[:class:`Member`, :class:`User`] + + .. attribute:: volume + + The volume of a soundboard sound. + + See also :attr:`SoundboardSound.volume` + + :type: :class:`float` + .. this is currently missing the following keys: reason and application_id I'm not sure how to port these @@ -4691,6 +4874,13 @@ Guild :type: List[:class:`Object`] +GuildPreview +~~~~~~~~~~~~ + +.. attributetable:: GuildPreview + +.. autoclass:: GuildPreview + :members: ScheduledEvent ~~~~~~~~~~~~~~ @@ -4858,6 +5048,35 @@ VoiceChannel :members: :inherited-members: +.. attributetable:: VoiceChannelEffect + +.. autoclass:: VoiceChannelEffect() + :members: + :inherited-members: + +.. class:: VoiceChannelEffectAnimation + + A namedtuple which represents a voice channel effect animation. + + .. versionadded:: 2.5 + + .. attribute:: id + + The ID of the animation. + + :type: :class:`int` + .. attribute:: type + + The type of the animation. + + :type: :class:`VoiceChannelEffectAnimationType` + +.. attributetable:: VoiceChannelSoundEffect + +.. autoclass:: VoiceChannelSoundEffect() + :members: + :inherited-members: + StageChannel ~~~~~~~~~~~~~ @@ -5024,6 +5243,30 @@ GuildSticker .. autoclass:: GuildSticker() :members: +BaseSoundboardSound +~~~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: BaseSoundboardSound + +.. autoclass:: BaseSoundboardSound() + :members: + +SoundboardDefaultSound +~~~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: SoundboardDefaultSound + +.. autoclass:: SoundboardDefaultSound() + :members: + +SoundboardSound +~~~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: SoundboardSound + +.. autoclass:: SoundboardSound() + :members: + ShardInfo ~~~~~~~~~~~ @@ -5032,6 +5275,14 @@ ShardInfo .. autoclass:: ShardInfo() :members: +SessionStartLimits +~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: SessionStartLimits + +.. autoclass:: SessionStartLimits() + :members: + SKU ~~~~~~~~~~~ @@ -5048,6 +5299,14 @@ Entitlement .. autoclass:: Entitlement() :members: +Subscription +~~~~~~~~~~~~ + +.. attributetable:: Subscription + +.. autoclass:: Subscription() + :members: + RawMessageDeleteEvent ~~~~~~~~~~~~~~~~~~~~~~~ @@ -5186,6 +5445,14 @@ PollAnswer .. _discord_api_data: +MessageSnapshot +~~~~~~~~~~~~~~~~~ + +.. attributetable:: MessageSnapshot + +.. autoclass:: MessageSnapshot + :members: + Data Classes -------------- @@ -5257,6 +5524,22 @@ RoleSubscriptionInfo .. autoclass:: RoleSubscriptionInfo :members: +PurchaseNotification +~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: PurchaseNotification + +.. autoclass:: PurchaseNotification() + :members: + +GuildProductPurchase ++++++++++++++++++++++ + +.. attributetable:: GuildProductPurchase + +.. autoclass:: GuildProductPurchase() + :members: + Intents ~~~~~~~~~~ @@ -5465,6 +5748,14 @@ PollMedia .. autoclass:: PollMedia :members: +CallMessage +~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: CallMessage + +.. autoclass:: CallMessage() + :members: + ScheduledEventRecurrenceRule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -5508,6 +5799,8 @@ The following exceptions are thrown by the library. .. autoexception:: InteractionResponded +.. autoexception:: MissingApplicationID + .. autoexception:: discord.opus.OpusError .. autoexception:: discord.opus.OpusNotLoaded @@ -5525,6 +5818,7 @@ Exception Hierarchy - :exc:`ConnectionClosed` - :exc:`PrivilegedIntentsRequired` - :exc:`InteractionResponded` + - :exc:`MissingApplicationID` - :exc:`GatewayNotFound` - :exc:`HTTPException` - :exc:`Forbidden` diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index 9bda24f6e..f55225614 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -708,6 +708,9 @@ Exceptions .. autoexception:: discord.ext.commands.ScheduledEventNotFound :members: +.. autoexception:: discord.ext.commands.SoundboardSoundNotFound + :members: + .. autoexception:: discord.ext.commands.BadBoolArgument :members: @@ -800,6 +803,7 @@ Exception Hierarchy - :exc:`~.commands.EmojiNotFound` - :exc:`~.commands.GuildStickerNotFound` - :exc:`~.commands.ScheduledEventNotFound` + - :exc:`~.commands.SoundboardSoundNotFound` - :exc:`~.commands.PartialEmojiConversionFailure` - :exc:`~.commands.BadBoolArgument` - :exc:`~.commands.RangeError` diff --git a/docs/interactions/api.rst b/docs/interactions/api.rst index 211cd790f..aeb6a25c6 100644 --- a/docs/interactions/api.rst +++ b/docs/interactions/api.rst @@ -872,9 +872,6 @@ Exceptions .. autoexception:: discord.app_commands.CommandNotFound :members: -.. autoexception:: discord.app_commands.MissingApplicationID - :members: - .. autoexception:: discord.app_commands.CommandSyncFailure :members: @@ -899,7 +896,7 @@ Exception Hierarchy - :exc:`~discord.app_commands.CommandAlreadyRegistered` - :exc:`~discord.app_commands.CommandSignatureMismatch` - :exc:`~discord.app_commands.CommandNotFound` - - :exc:`~discord.app_commands.MissingApplicationID` + - :exc:`~discord.MissingApplicationID` - :exc:`~discord.app_commands.CommandSyncFailure` - :exc:`~discord.HTTPException` - :exc:`~discord.app_commands.CommandSyncFailure` diff --git a/pyproject.toml b/pyproject.toml index 596e6ef08..d7360731d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,12 +50,15 @@ docs = [ "sphinxcontrib-serializinghtml==1.1.5", "typing-extensions>=4.3,<5", "sphinx-inline-tabs==2023.4.21", + # TODO: Remove this when moving to Sphinx >= 6.6 + "imghdr-lts==1.0.0; python_version>='3.13'", ] speed = [ "orjson>=3.5.4", "aiodns>=1.1; sys_platform != 'win32'", "Brotli", "cchardet==2.1.7; python_version < '3.10'", + "zstandard>=0.23.0" ] test = [ "coverage[toml]", @@ -66,6 +69,10 @@ test = [ "typing-extensions>=4.3,<5", "tzdata; sys_platform == 'win32'", ] +dev = [ + "black==22.6", + "typing_extensions>=4.3,<5", +] [tool.setuptools] packages = [ diff --git a/tests/test_colour.py b/tests/test_colour.py index bf0e59713..b79f153f0 100644 --- a/tests/test_colour.py +++ b/tests/test_colour.py @@ -44,6 +44,7 @@ import pytest ('rgb(20%, 24%, 56%)', 0x333D8F), ('rgb(20%, 23.9%, 56.1%)', 0x333D8F), ('rgb(51, 61, 143)', 0x333D8F), + ('0x#333D8F', 0x333D8F), ], ) def test_from_str(value, expected): @@ -53,6 +54,7 @@ def test_from_str(value, expected): @pytest.mark.parametrize( ('value'), [ + None, 'not valid', '0xYEAH', '#YEAH', @@ -62,8 +64,72 @@ def test_from_str(value, expected): 'rgb(30, -1, 60)', 'invalid(a, b, c)', 'rgb(', + '#1000000', + '#FFFFFFF', + "rgb(101%, 50%, 50%)", + "rgb(50%, -10%, 50%)", + "rgb(50%, 50%, 150%)", + "rgb(256, 100, 100)", ], ) def test_from_str_failures(value): with pytest.raises(ValueError): discord.Colour.from_str(value) + + +@pytest.mark.parametrize( + ('value', 'expected'), + [ + (discord.Colour.default(), 0x000000), + (discord.Colour.teal(), 0x1ABC9C), + (discord.Colour.dark_teal(), 0x11806A), + (discord.Colour.brand_green(), 0x57F287), + (discord.Colour.green(), 0x2ECC71), + (discord.Colour.dark_green(), 0x1F8B4C), + (discord.Colour.blue(), 0x3498DB), + (discord.Colour.dark_blue(), 0x206694), + (discord.Colour.purple(), 0x9B59B6), + (discord.Colour.dark_purple(), 0x71368A), + (discord.Colour.magenta(), 0xE91E63), + (discord.Colour.dark_magenta(), 0xAD1457), + (discord.Colour.gold(), 0xF1C40F), + (discord.Colour.dark_gold(), 0xC27C0E), + (discord.Colour.orange(), 0xE67E22), + (discord.Colour.dark_orange(), 0xA84300), + (discord.Colour.brand_red(), 0xED4245), + (discord.Colour.red(), 0xE74C3C), + (discord.Colour.dark_red(), 0x992D22), + (discord.Colour.lighter_grey(), 0x95A5A6), + (discord.Colour.dark_grey(), 0x607D8B), + (discord.Colour.light_grey(), 0x979C9F), + (discord.Colour.darker_grey(), 0x546E7A), + (discord.Colour.og_blurple(), 0x7289DA), + (discord.Colour.blurple(), 0x5865F2), + (discord.Colour.greyple(), 0x99AAB5), + (discord.Colour.dark_theme(), 0x313338), + (discord.Colour.fuchsia(), 0xEB459E), + (discord.Colour.yellow(), 0xFEE75C), + (discord.Colour.dark_embed(), 0x2B2D31), + (discord.Colour.light_embed(), 0xEEEFF1), + (discord.Colour.pink(), 0xEB459F), + ], +) +def test_static_colours(value, expected): + assert value.value == expected + + + + +@pytest.mark.parametrize( + ('value', 'property', 'expected'), + [ + (discord.Colour(0x000000), 'r', 0), + (discord.Colour(0xFFFFFF), 'g', 255), + (discord.Colour(0xABCDEF), 'b', 239), + (discord.Colour(0x44243B), 'r', 68), + (discord.Colour(0x333D8F), 'g', 61), + (discord.Colour(0xDBFF00), 'b', 0), + ], +) +def test_colour_properties(value, property, expected): + assert getattr(value, property) == expected diff --git a/tests/test_embed.py b/tests/test_embed.py new file mode 100644 index 000000000..3efedd6a5 --- /dev/null +++ b/tests/test_embed.py @@ -0,0 +1,269 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import datetime + +import discord +import pytest + + +@pytest.mark.parametrize( + ('title', 'description', 'colour', 'url'), + [ + ('title', 'description', 0xABCDEF, 'https://example.com'), + ('title', 'description', 0xFF1294, None), + ('title', 'description', discord.Colour(0x333D8F), 'https://example.com'), + ('title', 'description', discord.Colour(0x44243B), None), + ], +) +def test_embed_initialization(title, description, colour, url): + embed = discord.Embed(title=title, description=description, colour=colour, url=url) + assert embed.title == title + assert embed.description == description + assert embed.colour == colour or embed.colour == discord.Colour(colour) + assert embed.url == url + + +@pytest.mark.parametrize( + ('text', 'icon_url'), + [ + ('Hello discord.py', 'https://example.com'), + ('text', None), + (None, 'https://example.com'), + (None, None), + ], +) +def test_embed_set_footer(text, icon_url): + embed = discord.Embed() + embed.set_footer(text=text, icon_url=icon_url) + assert embed.footer.text == text + assert embed.footer.icon_url == icon_url + + +def test_embed_remove_footer(): + embed = discord.Embed() + embed.set_footer(text='Hello discord.py', icon_url='https://example.com') + embed.remove_footer() + assert embed.footer.text is None + assert embed.footer.icon_url is None + + +@pytest.mark.parametrize( + ('name', 'url', 'icon_url'), + [ + ('Rapptz', 'http://example.com', 'http://example.com/icon.png'), + ('NCPlayz', None, 'http://example.com/icon.png'), + ('Jackenmen', 'http://example.com', None), + ], +) +def test_embed_set_author(name, url, icon_url): + embed = discord.Embed() + embed.set_author(name=name, url=url, icon_url=icon_url) + assert embed.author.name == name + assert embed.author.url == url + assert embed.author.icon_url == icon_url + + +def test_embed_remove_author(): + embed = discord.Embed() + embed.set_author(name='Rapptz', url='http://example.com', icon_url='http://example.com/icon.png') + embed.remove_author() + assert embed.author.name is None + assert embed.author.url is None + assert embed.author.icon_url is None + + +@pytest.mark.parametrize( + ('thumbnail'), + [ + ('http://example.com'), + (None), + ], +) +def test_embed_set_thumbnail(thumbnail): + embed = discord.Embed() + embed.set_thumbnail(url=thumbnail) + assert embed.thumbnail.url == thumbnail + + +@pytest.mark.parametrize( + ('image'), + [ + ('http://example.com'), + (None), + ], +) +def test_embed_set_image(image): + embed = discord.Embed() + embed.set_image(url=image) + assert embed.image.url == image + + +@pytest.mark.parametrize( + ('name', 'value', 'inline'), + [ + ('music', 'music value', True), + ('sport', 'sport value', False), + ], +) +def test_embed_add_field(name, value, inline): + embed = discord.Embed() + embed.add_field(name=name, value=value, inline=inline) + assert len(embed.fields) == 1 + assert embed.fields[0].name == name + assert embed.fields[0].value == value + assert embed.fields[0].inline == inline + + +def test_embed_insert_field(): + embed = discord.Embed() + embed.add_field(name='name', value='value', inline=True) + embed.insert_field_at(0, name='name 2', value='value 2', inline=False) + assert embed.fields[0].name == 'name 2' + assert embed.fields[0].value == 'value 2' + assert embed.fields[0].inline is False + + +def test_embed_set_field_at(): + embed = discord.Embed() + embed.add_field(name='name', value='value', inline=True) + embed.set_field_at(0, name='name 2', value='value 2', inline=False) + assert embed.fields[0].name == 'name 2' + assert embed.fields[0].value == 'value 2' + assert embed.fields[0].inline is False + + +def test_embed_set_field_at_failure(): + embed = discord.Embed() + with pytest.raises(IndexError): + embed.set_field_at(0, name='name', value='value', inline=True) + + +def test_embed_clear_fields(): + embed = discord.Embed() + embed.add_field(name="field 1", value="value 1", inline=False) + embed.add_field(name="field 2", value="value 2", inline=False) + embed.add_field(name="field 3", value="value 3", inline=False) + embed.clear_fields() + assert len(embed.fields) == 0 + + +def test_embed_remove_field(): + embed = discord.Embed() + embed.add_field(name='name', value='value', inline=True) + embed.remove_field(0) + assert len(embed.fields) == 0 + + +@pytest.mark.parametrize( + ('title', 'description', 'url'), + [ + ('title 1', 'description 1', 'https://example.com'), + ('title 2', 'description 2', None), + ], +) +def test_embed_copy(title, description, url): + embed = discord.Embed(title=title, description=description, url=url) + embed_copy = embed.copy() + + assert embed == embed_copy + assert embed.title == embed_copy.title + assert embed.description == embed_copy.description + assert embed.url == embed_copy.url + + +@pytest.mark.parametrize( + ('title', 'description'), + [ + ('title 1', 'description 1'), + ('title 2', 'description 2'), + ], +) +def test_embed_len(title, description): + embed = discord.Embed(title=title, description=description) + assert len(embed) == len(title) + len(description) + + +@pytest.mark.parametrize( + ('title', 'description', 'fields', 'footer', 'author'), + [ + ( + 'title 1', + 'description 1', + [('field name 1', 'field value 1'), ('field name 2', 'field value 2')], + 'footer 1', + 'author 1', + ), + ('title 2', 'description 2', [('field name 3', 'field value 3')], 'footer 2', 'author 2'), + ], +) +def test_embed_len_with_options(title, description, fields, footer, author): + embed = discord.Embed(title=title, description=description) + for name, value in fields: + embed.add_field(name=name, value=value) + embed.set_footer(text=footer) + embed.set_author(name=author) + assert len(embed) == len(title) + len(description) + len("".join([name + value for name, value in fields])) + len( + footer + ) + len(author) + + +def test_embed_to_dict(): + timestamp = datetime.datetime.now(datetime.timezone.utc) + embed = discord.Embed(title="Test Title", description="Test Description", timestamp=timestamp) + data = embed.to_dict() + assert data['title'] == "Test Title" + assert data['description'] == "Test Description" + assert data['timestamp'] == timestamp.isoformat() + + +def test_embed_from_dict(): + data = { + 'title': 'Test Title', + 'description': 'Test Description', + 'url': 'http://example.com', + 'color': 0x00FF00, + 'timestamp': '2024-07-03T12:34:56+00:00', + } + embed = discord.Embed.from_dict(data) + assert embed.title == 'Test Title' + assert embed.description == 'Test Description' + assert embed.url == 'http://example.com' + assert embed.colour is not None and embed.colour.value == 0x00FF00 + assert embed.timestamp is not None and embed.timestamp.isoformat() == '2024-07-03T12:34:56+00:00' + + +@pytest.mark.parametrize( + ('value'), + [ + -0.5, + '#FFFFFF', + ], +) +def test_embed_colour_setter_failure(value): + embed = discord.Embed() + with pytest.raises(TypeError): + embed.colour = value diff --git a/tests/test_files.py b/tests/test_files.py index 6096c3a38..72ff3b7b3 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -27,6 +27,7 @@ from __future__ import annotations from io import BytesIO import discord +import pytest FILE = BytesIO() @@ -127,3 +128,58 @@ def test_file_not_spoiler_with_overriding_name_double_spoiler(): f.filename = 'SPOILER_SPOILER_.gitignore' assert f.filename == 'SPOILER_.gitignore' assert f.spoiler == True + + +def test_file_reset(): + f = discord.File('.gitignore') + + f.reset(seek=True) + assert f.fp.tell() == 0 + + f.reset(seek=False) + assert f.fp.tell() == 0 + + +def test_io_reset(): + f = discord.File(FILE) + + f.reset(seek=True) + assert f.fp.tell() == 0 + + f.reset(seek=False) + assert f.fp.tell() == 0 + + +def test_io_failure(): + class NonSeekableReadable(BytesIO): + def seekable(self): + return False + + def readable(self): + return False + + f = NonSeekableReadable() + + with pytest.raises(ValueError) as excinfo: + discord.File(f) + + assert str(excinfo.value) == f"File buffer {f!r} must be seekable and readable" + + +def test_io_to_dict(): + buffer = BytesIO(b"test content") + file = discord.File(buffer, filename="test.txt", description="test description") + + data = file.to_dict(0) + assert data["id"] == 0 + assert data["filename"] == "test.txt" + assert data["description"] == "test description" + + +def test_file_to_dict(): + f = discord.File('.gitignore', description="test description") + + data = f.to_dict(0) + assert data["id"] == 0 + assert data["filename"] == ".gitignore" + assert data["description"] == "test description" diff --git a/tests/test_permissions_all.py b/tests/test_permissions_all.py new file mode 100644 index 000000000..883dc1b63 --- /dev/null +++ b/tests/test_permissions_all.py @@ -0,0 +1,7 @@ +import discord + +from functools import reduce +from operator import or_ + +def test_permissions_all(): + assert discord.Permissions.all().value == reduce(or_, discord.Permissions.VALID_FLAGS.values()) diff --git a/tests/test_ui_buttons.py b/tests/test_ui_buttons.py new file mode 100644 index 000000000..55c0c7cd8 --- /dev/null +++ b/tests/test_ui_buttons.py @@ -0,0 +1,167 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import discord +import pytest + + +def test_button_init(): + button = discord.ui.Button( + label="Click me!", + ) + assert button.label == "Click me!" + assert button.style == discord.ButtonStyle.secondary + assert button.disabled == False + assert button.url == None + assert button.emoji == None + assert button.sku_id == None + + +def test_button_with_sku_id(): + button = discord.ui.Button( + label="Click me!", + sku_id=1234567890, + ) + assert button.label == "Click me!" + assert button.style == discord.ButtonStyle.premium + assert button.sku_id == 1234567890 + + +def test_button_with_url(): + button = discord.ui.Button( + label="Click me!", + url="https://example.com", + ) + assert button.label == "Click me!" + assert button.style == discord.ButtonStyle.link + assert button.url == "https://example.com" + + +def test_mix_both_custom_id_and_url(): + with pytest.raises(TypeError): + discord.ui.Button( + label="Click me!", + url="https://example.com", + custom_id="test", + ) + + +def test_mix_both_custom_id_and_sku_id(): + with pytest.raises(TypeError): + discord.ui.Button( + label="Click me!", + sku_id=1234567890, + custom_id="test", + ) + + +def test_mix_both_url_and_sku_id(): + with pytest.raises(TypeError): + discord.ui.Button( + label="Click me!", + url="https://example.com", + sku_id=1234567890, + ) + + +def test_invalid_url(): + button = discord.ui.Button( + label="Click me!", + ) + with pytest.raises(TypeError): + button.url = 1234567890 # type: ignore + + +def test_invalid_custom_id(): + with pytest.raises(TypeError): + discord.ui.Button( + label="Click me!", + custom_id=1234567890, # type: ignore + ) + + button = discord.ui.Button( + label="Click me!", + ) + with pytest.raises(TypeError): + button.custom_id = 1234567890 # type: ignore + + +def test_button_with_partial_emoji(): + button = discord.ui.Button( + label="Click me!", + emoji="👍", + ) + assert button.label == "Click me!" + assert button.emoji is not None and button.emoji.name == "👍" + + +def test_button_with_str_emoji(): + emoji = discord.PartialEmoji(name="👍") + button = discord.ui.Button( + label="Click me!", + emoji=emoji, + ) + assert button.label == "Click me!" + assert button.emoji == emoji + + +def test_button_with_invalid_emoji(): + with pytest.raises(TypeError): + discord.ui.Button( + label="Click me!", + emoji=-0.53, # type: ignore + ) + + button = discord.ui.Button( + label="Click me!", + ) + with pytest.raises(TypeError): + button.emoji = -0.53 # type: ignore + + +def test_button_setter(): + button = discord.ui.Button() + + button.label = "Click me!" + assert button.label == "Click me!" + + button.style = discord.ButtonStyle.primary + assert button.style == discord.ButtonStyle.primary + + button.disabled = True + assert button.disabled == True + + button.url = "https://example.com" + assert button.url == "https://example.com" + + button.emoji = "👍" + assert button.emoji is not None and button.emoji.name == "👍" # type: ignore + + button.custom_id = "test" + assert button.custom_id == "test" + + button.sku_id = 1234567890 + assert button.sku_id == 1234567890 diff --git a/tests/test_ui_modals.py b/tests/test_ui_modals.py new file mode 100644 index 000000000..dd1ac7169 --- /dev/null +++ b/tests/test_ui_modals.py @@ -0,0 +1,102 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import discord +import pytest + + +@pytest.mark.asyncio +async def test_modal_init(): + modal = discord.ui.Modal( + title="Temp Title", + ) + assert modal.title == "Temp Title" + assert modal.timeout == None + + +@pytest.mark.asyncio +async def test_no_title(): + with pytest.raises(ValueError) as excinfo: + discord.ui.Modal() + + assert str(excinfo.value) == "Modal must have a title" + + +@pytest.mark.asyncio +async def test_to_dict(): + modal = discord.ui.Modal( + title="Temp Title", + ) + data = modal.to_dict() + assert data["custom_id"] is not None + assert data["title"] == "Temp Title" + assert data["components"] == [] + + +@pytest.mark.asyncio +async def test_add_item(): + modal = discord.ui.Modal( + title="Temp Title", + ) + item = discord.ui.TextInput(label="Test") + modal.add_item(item) + + assert modal.children == [item] + + +@pytest.mark.asyncio +async def test_add_item_invalid(): + modal = discord.ui.Modal( + title="Temp Title", + ) + with pytest.raises(TypeError): + modal.add_item("Not an item") # type: ignore + + +@pytest.mark.asyncio +async def test_maximum_items(): + modal = discord.ui.Modal( + title="Temp Title", + ) + max_item_limit = 5 + + for i in range(max_item_limit): + modal.add_item(discord.ui.TextInput(label=f"Test {i}")) + + with pytest.raises(ValueError): + modal.add_item(discord.ui.TextInput(label="Test")) + + +@pytest.mark.asyncio +async def test_modal_setters(): + modal = discord.ui.Modal( + title="Temp Title", + ) + modal.title = "New Title" + assert modal.title == "New Title" + + modal.timeout = 120 + assert modal.timeout == 120