Browse Source

Merge remote-tracking branch 'upstream/master' into feature/guild/onboarding

pull/10226/head
Soheab_ 2 weeks ago
parent
commit
c986598fce
  1. 2
      .github/workflows/lint.yml
  2. 5
      .readthedocs.yml
  3. 2
      MANIFEST.in
  4. 7
      README.rst
  5. 15
      discord/__init__.py
  6. 10
      discord/__main__.py
  7. 155
      discord/abc.py
  8. 48
      discord/activity.py
  9. 1
      discord/app_commands/__init__.py
  10. 2
      discord/app_commands/checks.py
  11. 480
      discord/app_commands/commands.py
  12. 64
      discord/app_commands/errors.py
  13. 213
      discord/app_commands/installs.py
  14. 110
      discord/app_commands/models.py
  15. 4
      discord/app_commands/namespace.py
  16. 59
      discord/app_commands/transformers.py
  17. 2
      discord/app_commands/translator.py
  18. 78
      discord/app_commands/tree.py
  19. 327
      discord/appinfo.py
  20. 26
      discord/asset.py
  21. 189
      discord/audit_logs.py
  22. 156
      discord/automod.py
  23. 596
      discord/channel.py
  24. 878
      discord/client.py
  25. 103
      discord/colour.py
  26. 155
      discord/components.py
  27. 80
      discord/embeds.py
  28. 49
      discord/emoji.py
  29. 362
      discord/enums.py
  30. 25
      discord/errors.py
  31. 40
      discord/ext/commands/bot.py
  32. 45
      discord/ext/commands/cog.py
  33. 222
      discord/ext/commands/context.py
  34. 182
      discord/ext/commands/converter.py
  35. 4
      discord/ext/commands/cooldowns.py
  36. 39
      discord/ext/commands/core.py
  37. 82
      discord/ext/commands/errors.py
  38. 45
      discord/ext/commands/flags.py
  39. 15
      discord/ext/commands/help.py
  40. 67
      discord/ext/commands/hybrid.py
  41. 51
      discord/ext/commands/parameters.py
  42. 22
      discord/ext/tasks/__init__.py
  43. 2
      discord/file.py
  44. 765
      discord/flags.py
  45. 167
      discord/gateway.py
  46. 1137
      discord/guild.py
  47. 446
      discord/http.py
  48. 422
      discord/interactions.py
  49. 31
      discord/invite.py
  50. 232
      discord/member.py
  51. 930
      discord/message.py
  52. 2
      discord/object.py
  53. 2
      discord/oggparse.py
  54. 68
      discord/opus.py
  55. 2
      discord/partial_emoji.py
  56. 195
      discord/permissions.py
  57. 163
      discord/player.py
  58. 672
      discord/poll.py
  59. 150
      discord/presences.py
  60. 148
      discord/raw_models.py
  61. 42
      discord/reaction.py
  62. 121
      discord/role.py
  63. 120
      discord/scheduled_event.py
  64. 96
      discord/shard.py
  65. 359
      discord/sku.py
  66. 325
      discord/soundboard.py
  67. 298
      discord/state.py
  68. 32
      discord/sticker.py
  69. 107
      discord/subscription.py
  70. 21
      discord/team.py
  71. 27
      discord/template.py
  72. 26
      discord/threads.py
  73. 1
      discord/types/activity.py
  74. 31
      discord/types/appinfo.py
  75. 24
      discord/types/audit_log.py
  76. 7
      discord/types/automod.py
  77. 44
      discord/types/channel.py
  78. 4
      discord/types/command.py
  79. 14
      discord/types/components.py
  80. 26
      discord/types/embed.py
  81. 2
      discord/types/emoji.py
  82. 46
      discord/types/gateway.py
  83. 24
      discord/types/guild.py
  84. 113
      discord/types/interactions.py
  85. 4
      discord/types/invite.py
  86. 8
      discord/types/member.py
  87. 98
      discord/types/message.py
  88. 88
      discord/types/poll.py
  89. 1
      discord/types/role.py
  90. 53
      discord/types/sku.py
  91. 49
      discord/types/soundboard.py
  92. 4
      discord/types/sticker.py
  93. 43
      discord/types/subscription.py
  94. 3
      discord/types/team.py
  95. 12
      discord/types/user.py
  96. 7
      discord/types/voice.py
  97. 1
      discord/ui/__init__.py
  98. 47
      discord/ui/button.py
  99. 216
      discord/ui/dynamic.py
  100. 35
      discord/ui/item.py

2
.github/workflows/lint.yml

@ -38,7 +38,7 @@ jobs:
- name: Run Pyright
uses: jakebailey/pyright-action@v1
with:
version: '1.1.289'
version: '1.1.394'
warnings: false
no-comments: ${{ matrix.python-version != '3.x' }}

5
.readthedocs.yml

@ -2,7 +2,9 @@ version: 2
formats: []
build:
image: latest
os: "ubuntu-22.04"
tools:
python: "3.8"
sphinx:
configuration: docs/conf.py
@ -10,7 +12,6 @@ sphinx:
builder: html
python:
version: 3.8
install:
- method: pip
path: .

2
MANIFEST.in

@ -1,5 +1,5 @@
include README.rst
include LICENSE
include requirements.txt
include discord/bin/*.dll
include discord/bin/*
include discord/py.typed

7
README.rst

@ -27,6 +27,13 @@ Installing
To install the library without full voice support, you can just run the following command:
.. note::
A `Virtual Environment <https://docs.python.org/3/library/venv.html>`__ is recommended to install
the library, especially on Linux where the system Python is externally managed and restricts which
packages you can install on it.
.. code:: sh
# Linux/macOS

15
discord/__init__.py

@ -13,7 +13,7 @@ __title__ = 'discord'
__author__ = 'Rapptz'
__license__ = 'MIT'
__copyright__ = 'Copyright 2015-present Rapptz'
__version__ = '2.2.0a'
__version__ = '2.6.0a'
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
@ -41,6 +41,7 @@ from .integrations import *
from .invite import *
from .template import *
from .welcome_screen import *
from .sku import *
from .widget import *
from .object import *
from .reaction import *
@ -68,6 +69,10 @@ from .interactions import *
from .components import *
from .threads import *
from .automod import *
from .poll import *
from .soundboard import *
from .subscription import *
from .presences import *
class VersionInfo(NamedTuple):
@ -78,8 +83,14 @@ class VersionInfo(NamedTuple):
serial: int
version_info: VersionInfo = VersionInfo(major=2, minor=2, micro=0, releaselevel='alpha', serial=0)
version_info: VersionInfo = VersionInfo(major=2, minor=6, micro=0, releaselevel='alpha', serial=0)
logging.getLogger(__name__).addHandler(logging.NullHandler())
# This is a backwards compatibility hack and should be removed in v3
# Essentially forcing the exception to have different base classes
# In the future, this should only inherit from ClientException
if len(MissingApplicationID.__bases__) == 1:
MissingApplicationID.__bases__ = (app_commands.AppCommandError, ClientException)
del logging, NamedTuple, Literal, VersionInfo

10
discord/__main__.py

@ -28,7 +28,7 @@ from typing import Optional, Tuple, Dict
import argparse
import sys
from pathlib import Path
from pathlib import Path, PurePath, PureWindowsPath
import discord
import importlib.metadata
@ -225,8 +225,14 @@ def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool
)
if len(name) <= 4 and name.upper() in forbidden:
parser.error('invalid directory name given, use a different one')
path = PurePath(name)
if isinstance(path, PureWindowsPath) and path.drive:
drive, rest = path.parts[0], path.parts[1:]
transformed = tuple(map(lambda p: p.translate(_translation_table), rest))
name = drive + '\\'.join(transformed)
name = name.translate(_translation_table)
else:
name = name.translate(_translation_table)
if replace_spaces:
name = name.replace(' ', '-')
return Path(name)

155
discord/abc.py

@ -26,6 +26,7 @@ from __future__ import annotations
import copy
import time
import secrets
import asyncio
from datetime import datetime
from typing import (
@ -48,8 +49,8 @@ from typing import (
from .object import OLDEST_OBJECT, Object
from .context_managers import Typing
from .enums import ChannelType
from .errors import ClientException
from .enums import ChannelType, InviteTarget
from .errors import ClientException, NotFound
from .mentions import AllowedMentions
from .permissions import PermissionOverwrite, Permissions
from .role import Role
@ -59,6 +60,7 @@ from .http import handle_message_parameters
from .voice_client import VoiceClient, VoiceProtocol
from .sticker import GuildSticker, StickerItem
from . import utils
from .flags import InviteFlags
__all__ = (
'Snowflake',
@ -83,9 +85,17 @@ if TYPE_CHECKING:
from .channel import CategoryChannel
from .embeds import Embed
from .message import Message, MessageReference, PartialMessage
from .channel import TextChannel, DMChannel, GroupChannel, PartialMessageable, VoiceChannel
from .channel import (
TextChannel,
DMChannel,
GroupChannel,
PartialMessageable,
VocalGuildChannel,
VoiceChannel,
StageChannel,
)
from .poll import Poll
from .threads import Thread
from .enums import InviteTarget
from .ui.view import View
from .types.channel import (
PermissionOverwrite as PermissionOverwritePayload,
@ -93,11 +103,14 @@ if TYPE_CHECKING:
GuildChannel as GuildChannelPayload,
OverwriteType,
)
from .types.guild import (
ChannelPositionUpdate,
)
from .types.snowflake import (
SnowflakeList,
)
PartialMessageableChannel = Union[TextChannel, VoiceChannel, Thread, DMChannel, PartialMessageable]
PartialMessageableChannel = Union[TextChannel, VoiceChannel, StageChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime]
@ -114,11 +127,18 @@ _undefined: Any = _Undefined()
async def _single_delete_strategy(messages: Iterable[Message], *, reason: Optional[str] = None):
for m in messages:
await m.delete()
try:
await m.delete()
except NotFound as exc:
if exc.code == 10008:
continue # bulk deletion ignores not found messages, single deletion does not.
# several other race conditions with deletion should fail without continuing,
# such as the channel being deleted and not found.
raise
async def _purge_helper(
channel: Union[Thread, TextChannel, VoiceChannel],
channel: Union[Thread, TextChannel, VocalGuildChannel],
*,
limit: Optional[int] = 100,
check: Callable[[Message], bool] = MISSING,
@ -211,7 +231,9 @@ class User(Snowflake, Protocol):
name: :class:`str`
The user's username.
discriminator: :class:`str`
The user's discriminator.
The user's discriminator. This is a legacy concept that is no longer used.
global_name: Optional[:class:`str`]
The user's global nickname.
bot: :class:`bool`
If the user is a bot account.
system: :class:`bool`
@ -220,6 +242,7 @@ class User(Snowflake, Protocol):
name: str
discriminator: str
global_name: Optional[str]
bot: bool
system: bool
@ -238,9 +261,25 @@ class User(Snowflake, Protocol):
"""Optional[:class:`~discord.Asset`]: Returns an Asset that represents the user's avatar, if present."""
raise NotImplementedError
@property
def avatar_decoration(self) -> Optional[Asset]:
"""Optional[:class:`~discord.Asset`]: Returns an Asset that represents the user's avatar decoration, if present.
.. versionadded:: 2.4
"""
raise NotImplementedError
@property
def avatar_decoration_sku_id(self) -> Optional[int]:
"""Optional[:class:`int`]: Returns an integer that represents the user's avatar decoration SKU ID, if present.
.. versionadded:: 2.4
"""
raise NotImplementedError
@property
def default_avatar(self) -> Asset:
""":class:`~discord.Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
""":class:`~discord.Asset`: Returns the default avatar for a given user."""
raise NotImplementedError
@property
@ -490,6 +529,13 @@ class GuildChannel:
raise TypeError('type field must be of type ChannelType')
options['type'] = ch_type.value
try:
status = options.pop('status')
except KeyError:
pass
else:
await self._state.http.edit_voice_channel_status(status, channel_id=self.id, reason=reason)
if options:
return await self._state.http.edit_channel(self.id, reason=reason, **options)
@ -665,6 +711,7 @@ class GuildChannel:
- Member overrides
- Implicit permissions
- Member timeout
- User installed app
If a :class:`~discord.Role` is passed, then it checks the permissions
someone with that role would have, which is essentially:
@ -680,6 +727,12 @@ class GuildChannel:
.. versionchanged:: 2.0
``obj`` parameter is now positional-only.
.. versionchanged:: 2.4
User installed apps are now taken into account.
The permissions returned for a user installed app mirrors the
permissions Discord returns in :attr:`~discord.Interaction.app_permissions`,
though it is recommended to use that attribute instead.
Parameters
----------
obj: Union[:class:`~discord.Member`, :class:`~discord.Role`]
@ -711,6 +764,13 @@ class GuildChannel:
return Permissions.all()
default = self.guild.default_role
if default is None:
if self._state.self_id == obj.id:
return Permissions._user_installed_permissions(in_guild=True)
else:
return Permissions.none()
base = Permissions(default.permissions.value)
# Handle the role case first
@ -935,8 +995,6 @@ class GuildChannel:
if len(permissions) > 0:
raise TypeError('Cannot mix overwrite and keyword arguments.')
# TODO: wait for event
if overwrite is None:
await http.delete_channel_permissions(self.id, target.id, reason=reason)
elif isinstance(overwrite, PermissionOverwrite):
@ -952,11 +1010,15 @@ class GuildChannel:
base_attrs: Dict[str, Any],
*,
name: Optional[str] = None,
category: Optional[CategoryChannel] = None,
reason: Optional[str] = None,
) -> Self:
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites]
base_attrs['parent_id'] = self.category_id
base_attrs['name'] = name or self.name
if category is not None:
base_attrs['parent_id'] = category.id
guild_id = self.guild.id
cls = self.__class__
data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs)
@ -966,7 +1028,13 @@ class GuildChannel:
self.guild._channels[obj.id] = obj # type: ignore # obj is a GuildChannel
return obj
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> Self:
async def clone(
self,
*,
name: Optional[str] = None,
category: Optional[CategoryChannel] = None,
reason: Optional[str] = None,
) -> Self:
"""|coro|
Clones this channel. This creates a channel with the same properties
@ -981,6 +1049,11 @@ class GuildChannel:
name: Optional[:class:`str`]
The name of the new channel. If not provided, defaults to this
channel name.
category: Optional[:class:`~discord.CategoryChannel`]
The category the new channel belongs to.
This parameter is ignored if cloning a category channel.
.. versionadded:: 2.5
reason: Optional[:class:`str`]
The reason for cloning this channel. Shows up on the audit log.
@ -1077,10 +1150,10 @@ class GuildChannel:
channel list (or category if given).
This is mutually exclusive with ``beginning``, ``before``, and ``after``.
before: :class:`~discord.abc.Snowflake`
The channel that should be before our current channel.
Whether to move the channel before the given channel.
This is mutually exclusive with ``beginning``, ``end``, and ``after``.
after: :class:`~discord.abc.Snowflake`
The channel that should be after our current channel.
Whether to move the channel after the given channel.
This is mutually exclusive with ``beginning``, ``end``, and ``before``.
offset: :class:`int`
The number of channels to offset the move by. For example,
@ -1163,11 +1236,11 @@ class GuildChannel:
raise ValueError('Could not resolve appropriate move position')
channels.insert(max((index + offset), 0), self)
payload = []
payload: List[ChannelPositionUpdate] = []
lock_permissions = kwargs.get('sync_permissions', False)
reason = kwargs.get('reason')
for index, channel in enumerate(channels):
d = {'id': channel.id, 'position': index}
d: ChannelPositionUpdate = {'id': channel.id, 'position': index}
if parent_id is not MISSING and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
@ -1185,6 +1258,7 @@ class GuildChannel:
target_type: Optional[InviteTarget] = None,
target_user: Optional[User] = None,
target_application_id: Optional[int] = None,
guest: bool = False,
) -> Invite:
"""|coro|
@ -1223,6 +1297,10 @@ class GuildChannel:
The id of the embedded application for the invite, required if ``target_type`` is :attr:`.InviteTarget.embedded_application`.
.. versionadded:: 2.0
guest: :class:`bool`
Whether the invite is a guest invite.
.. versionadded:: 2.6
Raises
-------
@ -1237,6 +1315,13 @@ class GuildChannel:
:class:`~discord.Invite`
The invite that was created.
"""
if target_type is InviteTarget.unknown:
raise ValueError('Cannot create invite with an unknown target type')
flags: Optional[InviteFlags] = None
if guest:
flags = InviteFlags._from_value(0)
flags.guest = True
data = await self._state.http.create_invite(
self.id,
@ -1248,6 +1333,7 @@ class GuildChannel:
target_type=target_type.value if target_type else None,
target_user_id=target_user.id if target_user else None,
target_application_id=target_application_id,
flags=flags.value if flags else None,
)
return Invite.from_incomplete(data=data, state=self._state)
@ -1284,6 +1370,7 @@ class Messageable:
- :class:`~discord.TextChannel`
- :class:`~discord.VoiceChannel`
- :class:`~discord.StageChannel`
- :class:`~discord.DMChannel`
- :class:`~discord.GroupChannel`
- :class:`~discord.PartialMessageable`
@ -1316,6 +1403,7 @@ class Messageable:
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@ -1336,6 +1424,7 @@ class Messageable:
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@ -1356,6 +1445,7 @@ class Messageable:
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@ -1376,6 +1466,7 @@ class Messageable:
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@ -1397,6 +1488,7 @@ class Messageable:
view: Optional[View] = None,
suppress_embeds: bool = False,
silent: bool = False,
poll: Optional[Poll] = None,
) -> Message:
"""|coro|
@ -1454,10 +1546,11 @@ class Messageable:
.. versionadded:: 1.4
reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`, :class:`~discord.PartialMessage`]
A reference to the :class:`~discord.Message` to which you are replying, this can be created using
:meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. You can control
whether this mentions the author of the referenced message using the :attr:`~discord.AllowedMentions.replied_user`
attribute of ``allowed_mentions`` or by setting ``mention_author``.
A reference to the :class:`~discord.Message` to which you are referencing, this can be created using
:meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`.
In the event of a replying reference, you can control whether this mentions the author of the referenced
message using the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions`` or by
setting ``mention_author``.
.. versionadded:: 1.6
@ -1482,6 +1575,10 @@ class Messageable:
in the UI, but will not actually send a notification.
.. versionadded:: 2.2
poll: :class:`~discord.Poll`
The poll to send with this message.
.. versionadded:: 2.4
Raises
--------
@ -1489,6 +1586,9 @@ class Messageable:
Sending the message failed.
~discord.Forbidden
You do not have the proper permissions to send the message.
~discord.NotFound
You sent a message with the same nonce as one that has been explicitly
deleted shortly earlier.
ValueError
The ``files`` or ``embeds`` list is not of the appropriate size.
TypeError
@ -1533,6 +1633,9 @@ class Messageable:
else:
flags = MISSING
if nonce is None:
nonce = secrets.randbits(64)
with handle_message_parameters(
content=content,
tts=tts,
@ -1548,6 +1651,7 @@ class Messageable:
stickers=sticker_ids,
view=view,
flags=flags,
poll=poll,
) as params:
data = await state.http.send_message(channel.id, params=params)
@ -1555,6 +1659,9 @@ class Messageable:
if view and not view.is_finished():
state.store_view(view, ret.id)
if poll:
poll._update(ret)
if delete_after is not None:
await ret.delete(delay=delete_after)
return ret
@ -1713,12 +1820,12 @@ class Messageable:
async def _around_strategy(retrieve: int, around: Optional[Snowflake], limit: Optional[int]):
if not around:
return []
return [], None, 0
around_id = around.id if around else None
data = await self._state.http.logs_from(channel.id, retrieve, around=around_id)
return data, None, limit
return data, None, 0
async def _after_strategy(retrieve: int, after: Optional[Snowflake], limit: Optional[int]):
after_id = after.id if after else None
@ -1831,7 +1938,7 @@ class Connectable(Protocol):
async def connect(
self,
*,
timeout: float = 60.0,
timeout: float = 30.0,
reconnect: bool = True,
cls: Callable[[Client, Connectable], T] = VoiceClient,
self_deaf: bool = False,
@ -1847,7 +1954,7 @@ class Connectable(Protocol):
Parameters
-----------
timeout: :class:`float`
The timeout in seconds to wait for the voice endpoint.
The timeout in seconds to wait the connection to complete.
reconnect: :class:`bool`
Whether the bot should automatically attempt
a reconnect if a part of the handshake fails

48
discord/activity.py

@ -162,6 +162,10 @@ class Activity(BaseActivity):
The user's current state. For example, "In Game".
details: Optional[:class:`str`]
The detail of the user's current activity.
platform: Optional[:class:`str`]
The user's current platform.
.. versionadded:: 2.4
timestamps: :class:`dict`
A dictionary of timestamps. It contains the following optional keys:
@ -197,6 +201,7 @@ class Activity(BaseActivity):
'state',
'details',
'timestamps',
'platform',
'assets',
'party',
'flags',
@ -215,6 +220,7 @@ class Activity(BaseActivity):
self.state: Optional[str] = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None)
self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {})
self.platform: Optional[str] = kwargs.pop('platform', None)
self.assets: ActivityAssets = kwargs.pop('assets', {})
self.party: ActivityParty = kwargs.pop('party', {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id')
@ -238,6 +244,7 @@ class Activity(BaseActivity):
('type', self.type),
('name', self.name),
('url', self.url),
('platform', self.platform),
('details', self.details),
('application_id', self.application_id),
('session_id', self.session_id),
@ -266,7 +273,7 @@ class Activity(BaseActivity):
def start(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps['start'] / 1000
timestamp = self.timestamps['start'] / 1000 # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
return None
else:
@ -276,7 +283,7 @@ class Activity(BaseActivity):
def end(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps['end'] / 1000
timestamp = self.timestamps['end'] / 1000 # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
return None
else:
@ -286,7 +293,7 @@ class Activity(BaseActivity):
def large_image_url(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity, if applicable."""
try:
large_image = self.assets['large_image']
large_image = self.assets['large_image'] # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
return None
else:
@ -296,7 +303,7 @@ class Activity(BaseActivity):
def small_image_url(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity, if applicable."""
try:
small_image = self.assets['small_image']
small_image = self.assets['small_image'] # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
return None
else:
@ -351,13 +358,30 @@ class Game(BaseActivity):
-----------
name: :class:`str`
The game's name.
platform: Optional[:class:`str`]
Where the user is playing from (ie. PS5, Xbox).
.. versionadded:: 2.4
assets: :class:`dict`
A dictionary representing the images and their hover text of a game.
It contains the following optional keys:
- ``large_image``: A string representing the ID for the large image asset.
- ``large_text``: A string representing the text when hovering over the large image asset.
- ``small_image``: A string representing the ID for the small image asset.
- ``small_text``: A string representing the text when hovering over the small image asset.
.. versionadded:: 2.4
"""
__slots__ = ('name', '_end', '_start')
__slots__ = ('name', '_end', '_start', 'platform', 'assets')
def __init__(self, name: str, **extra: Any) -> None:
super().__init__(**extra)
self.name: str = name
self.platform: Optional[str] = extra.get('platform')
self.assets: ActivityAssets = extra.get('assets', {}) or {}
try:
timestamps: ActivityTimestamps = extra['timestamps']
@ -394,7 +418,7 @@ class Game(BaseActivity):
return str(self.name)
def __repr__(self) -> str:
return f'<Game name={self.name!r}>'
return f'<Game name={self.name!r} platform={self.platform!r}>'
def to_dict(self) -> Dict[str, Any]:
timestamps: Dict[str, Any] = {}
@ -408,6 +432,8 @@ class Game(BaseActivity):
'type': ActivityType.playing.value,
'name': str(self.name),
'timestamps': timestamps,
'platform': str(self.platform) if self.platform else None,
'assets': self.assets,
}
def __eq__(self, other: object) -> bool:
@ -488,7 +514,7 @@ class Streaming(BaseActivity):
return str(self.name)
def __repr__(self) -> str:
return f'<Streaming name={self.name!r}>'
return f'<Streaming name={self.name!r} platform={self.platform!r}>'
@property
def twitch_name(self) -> Optional[str]:
@ -499,7 +525,7 @@ class Streaming(BaseActivity):
"""
try:
name = self.assets['large_image']
name = self.assets['large_image'] # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
return None
else:
@ -732,10 +758,12 @@ class CustomActivity(BaseActivity):
__slots__ = ('name', 'emoji', 'state')
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any) -> None:
def __init__(
self, name: Optional[str], *, emoji: Optional[Union[PartialEmoji, Dict[str, Any], str]] = None, **extra: Any
) -> None:
super().__init__(**extra)
self.name: Optional[str] = name
self.state: Optional[str] = extra.pop('state', None)
self.state: Optional[str] = extra.pop('state', name)
if self.name == 'Custom Status':
self.name = self.state

1
discord/app_commands/__init__.py

@ -16,5 +16,6 @@ from .tree import *
from .namespace import *
from .transformers import *
from .translator import *
from .installs import *
from . import checks as checks
from .checks import Cooldown as Cooldown

2
discord/app_commands/checks.py

@ -186,7 +186,7 @@ class Cooldown:
:class:`Cooldown`
A new instance of this cooldown.
"""
return Cooldown(self.rate, self.per)
return self.__class__(self.rate, self.per)
def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'

480
discord/app_commands/commands.py

@ -46,8 +46,10 @@ from typing import (
)
import re
from copy import copy as shallow_copy
from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale
from .installs import AppCommandContext, AppInstallationType
from .models import Choice
from .transformers import annotation_to_parameter, CommandParameter, NoneType
from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
@ -64,6 +66,8 @@ if TYPE_CHECKING:
from ..abc import Snowflake
from .namespace import Namespace
from .models import ChoiceT
from .tree import CommandTree
from .._types import ClientT
# Generally, these two libraries are supposed to be separate from each other.
# However, for type hinting purposes it's unfortunately necessary for one to
@ -86,6 +90,12 @@ __all__ = (
'autocomplete',
'guilds',
'guild_only',
'dm_only',
'private_channel_only',
'allowed_contexts',
'guild_install',
'user_install',
'allowed_installs',
'default_permissions',
)
@ -617,6 +627,16 @@ class Command(Generic[GroupT, P, T]):
Whether the command should only be usable in guild contexts.
Due to a Discord limitation, this does not work on subcommands.
allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`]
The contexts that the command is allowed to be used in.
Overrides ``guild_only`` if this is set.
.. versionadded:: 2.4
allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`]
The installation contexts that the command is allowed to be installed
on.
.. versionadded:: 2.4
nsfw: :class:`bool`
Whether the command is NSFW and should only work in NSFW channels.
@ -637,6 +657,8 @@ class Command(Generic[GroupT, P, T]):
nsfw: bool = False,
parent: Optional[Group] = None,
guild_ids: Optional[List[int]] = None,
allowed_contexts: Optional[AppCommandContext] = None,
allowed_installs: Optional[AppInstallationType] = None,
auto_locale_strings: bool = True,
extras: Dict[Any, Any] = MISSING,
):
@ -671,6 +693,13 @@ class Command(Generic[GroupT, P, T]):
callback, '__discord_app_commands_default_permissions__', None
)
self.guild_only: bool = getattr(callback, '__discord_app_commands_guild_only__', False)
self.allowed_contexts: Optional[AppCommandContext] = allowed_contexts or getattr(
callback, '__discord_app_commands_contexts__', None
)
self.allowed_installs: Optional[AppInstallationType] = allowed_installs or getattr(
callback, '__discord_app_commands_installation_types__', None
)
self.nsfw: bool = nsfw
self.extras: Dict[Any, Any] = extras or {}
@ -707,33 +736,18 @@ class Command(Generic[GroupT, P, T]):
) -> Command:
bindings = {} if bindings is MISSING else bindings
cls = self.__class__
copy = cls.__new__(cls)
copy.name = self.name
copy._locale_name = self._locale_name
copy._guild_ids = self._guild_ids
copy.checks = self.checks
copy.description = self.description
copy._locale_description = self._locale_description
copy.default_permissions = self.default_permissions
copy.guild_only = self.guild_only
copy.nsfw = self.nsfw
copy._attr = self._attr
copy._callback = self._callback
copy.on_error = self.on_error
copy = shallow_copy(self)
copy._params = self._params.copy()
copy.module = self.module
copy.parent = parent
copy.binding = bindings.get(self.binding) if self.binding is not None else binding
copy.extras = self.extras
if copy._attr and set_on_binding:
setattr(copy.binding, copy._attr, copy)
return copy
async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]:
base = self.to_dict()
async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]:
base = self.to_dict(tree)
name_localizations: Dict[str, str] = {}
description_localizations: Dict[str, str] = {}
@ -759,7 +773,7 @@ class Command(Generic[GroupT, P, T]):
]
return base
def to_dict(self) -> Dict[str, Any]:
def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]:
# If we have a parent then our type is a subcommand
# Otherwise, the type falls back to the specific command type (e.g. slash command or context menu)
option_type = AppCommandType.chat_input.value if self.parent is None else AppCommandOptionType.subcommand.value
@ -774,6 +788,8 @@ class Command(Generic[GroupT, P, T]):
base['nsfw'] = self.nsfw
base['dm_permission'] = not self.guild_only
base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value
base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts)
base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs)
return base
@ -887,7 +903,7 @@ class Command(Generic[GroupT, P, T]):
predicates = getattr(param.autocomplete, '__discord_app_commands_checks__', [])
if predicates:
try:
passed = await async_all(f(interaction) for f in predicates)
passed = await async_all(f(interaction) for f in predicates) # type: ignore
except Exception:
passed = False
@ -990,7 +1006,7 @@ class Command(Generic[GroupT, P, T]):
if self.binding is not None:
check: Optional[Check] = getattr(self.binding, 'interaction_check', None)
if check:
ret = await maybe_coroutine(check, interaction) # type: ignore # Probable pyright bug
ret = await maybe_coroutine(check, interaction)
if not ret:
return False
@ -998,7 +1014,7 @@ class Command(Generic[GroupT, P, T]):
if not predicates:
return True
return await async_all(f(interaction) for f in predicates)
return await async_all(f(interaction) for f in predicates) # type: ignore
def error(self, coro: Error[GroupT]) -> Error[GroupT]:
"""A decorator that registers a coroutine as a local error handler.
@ -1082,7 +1098,7 @@ class Command(Generic[GroupT, P, T]):
def decorator(coro: AutocompleteCallback[GroupT, ChoiceT]) -> AutocompleteCallback[GroupT, ChoiceT]:
if not inspect.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
raise TypeError('The autocomplete callback must be a coroutine function.')
try:
param = self._params[name]
@ -1181,6 +1197,16 @@ class ContextMenu:
guild_only: :class:`bool`
Whether the command should only be usable in guild contexts.
Defaults to ``False``.
allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`]
The contexts that this context menu is allowed to be used in.
Overrides ``guild_only`` if set.
.. versionadded:: 2.4
allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`]
The installation contexts that the command is allowed to be installed
on.
.. versionadded:: 2.4
nsfw: :class:`bool`
Whether the command is NSFW and should only work in NSFW channels.
Defaults to ``False``.
@ -1203,6 +1229,8 @@ class ContextMenu:
type: AppCommandType = MISSING,
nsfw: bool = False,
guild_ids: Optional[List[int]] = None,
allowed_contexts: Optional[AppCommandContext] = None,
allowed_installs: Optional[AppInstallationType] = None,
auto_locale_strings: bool = True,
extras: Dict[Any, Any] = MISSING,
):
@ -1228,6 +1256,12 @@ class ContextMenu:
)
self.nsfw: bool = nsfw
self.guild_only: bool = getattr(callback, '__discord_app_commands_guild_only__', False)
self.allowed_contexts: Optional[AppCommandContext] = allowed_contexts or getattr(
callback, '__discord_app_commands_contexts__', None
)
self.allowed_installs: Optional[AppInstallationType] = allowed_installs or getattr(
callback, '__discord_app_commands_installation_types__', None
)
self.checks: List[Check] = getattr(callback, '__discord_app_commands_checks__', [])
self.extras: Dict[Any, Any] = extras or {}
@ -1245,8 +1279,8 @@ class ContextMenu:
""":class:`str`: Returns the fully qualified command name."""
return self.name
async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]:
base = self.to_dict()
async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]:
base = self.to_dict(tree)
context = TranslationContext(location=TranslationContextLocation.command_name, data=self)
if self._locale_name:
name_localizations: Dict[str, str] = {}
@ -1258,11 +1292,13 @@ class ContextMenu:
base['name_localizations'] = name_localizations
return base
def to_dict(self) -> Dict[str, Any]:
def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]:
return {
'name': self.name,
'type': self.type.value,
'dm_permission': not self.guild_only,
'contexts': tree.allowed_contexts._merge_to_array(self.allowed_contexts),
'integration_types': tree.allowed_installs._merge_to_array(self.allowed_installs),
'default_member_permissions': None if self.default_permissions is None else self.default_permissions.value,
'nsfw': self.nsfw,
}
@ -1272,7 +1308,7 @@ class ContextMenu:
if not predicates:
return True
return await async_all(f(interaction) for f in predicates)
return await async_all(f(interaction) for f in predicates) # type: ignore
def _has_any_error_handlers(self) -> bool:
return self.on_error is not None
@ -1419,6 +1455,16 @@ class Group:
Whether the group should only be usable in guild contexts.
Due to a Discord limitation, this does not work on subcommands.
allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`]
The contexts that this group is allowed to be used in. Overrides
guild_only if set.
.. versionadded:: 2.4
allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`]
The installation contexts that the command is allowed to be installed
on.
.. versionadded:: 2.4
nsfw: :class:`bool`
Whether the command is NSFW and should only work in NSFW channels.
@ -1438,6 +1484,8 @@ class Group:
__discord_app_commands_group_locale_description__: Optional[locale_str] = None
__discord_app_commands_group_nsfw__: bool = False
__discord_app_commands_guild_only__: bool = MISSING
__discord_app_commands_contexts__: Optional[AppCommandContext] = MISSING
__discord_app_commands_installation_types__: Optional[AppInstallationType] = MISSING
__discord_app_commands_default_permissions__: Optional[Permissions] = MISSING
__discord_app_commands_has_module__: bool = False
__discord_app_commands_error_handler__: Optional[
@ -1506,6 +1554,8 @@ class Group:
parent: Optional[Group] = None,
guild_ids: Optional[List[int]] = None,
guild_only: bool = MISSING,
allowed_contexts: Optional[AppCommandContext] = MISSING,
allowed_installs: Optional[AppInstallationType] = MISSING,
nsfw: bool = MISSING,
auto_locale_strings: bool = True,
default_permissions: Optional[Permissions] = MISSING,
@ -1554,6 +1604,22 @@ class Group:
self.guild_only: bool = guild_only
if allowed_contexts is MISSING:
if cls.__discord_app_commands_contexts__ is MISSING:
allowed_contexts = None
else:
allowed_contexts = cls.__discord_app_commands_contexts__
self.allowed_contexts: Optional[AppCommandContext] = allowed_contexts
if allowed_installs is MISSING:
if cls.__discord_app_commands_installation_types__ is MISSING:
allowed_installs = None
else:
allowed_installs = cls.__discord_app_commands_installation_types__
self.allowed_installs: Optional[AppInstallationType] = allowed_installs
if nsfw is MISSING:
nsfw = cls.__discord_app_commands_group_nsfw__
@ -1562,6 +1628,9 @@ class Group:
if not self.description:
raise TypeError('groups must have a description')
if not self.name:
raise TypeError('groups must have a name')
self.parent: Optional[Group] = parent
self.module: Optional[str]
if cls.__discord_app_commands_has_module__:
@ -1622,22 +1691,9 @@ class Group:
) -> Group:
bindings = {} if bindings is MISSING else bindings
cls = self.__class__
copy = cls.__new__(cls)
copy.name = self.name
copy._locale_name = self._locale_name
copy._guild_ids = self._guild_ids
copy.description = self.description
copy._locale_description = self._locale_description
copy = shallow_copy(self)
copy.parent = parent
copy.module = self.module
copy.default_permissions = self.default_permissions
copy.guild_only = self.guild_only
copy.nsfw = self.nsfw
copy._attr = self._attr
copy._owner_cls = self._owner_cls
copy._children = {}
copy.extras = self.extras
bindings[self] = copy
@ -1657,8 +1713,8 @@ class Group:
return copy
async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]:
base = self.to_dict()
async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]:
base = self.to_dict(tree)
name_localizations: Dict[str, str] = {}
description_localizations: Dict[str, str] = {}
@ -1678,10 +1734,10 @@ class Group:
base['name_localizations'] = name_localizations
base['description_localizations'] = description_localizations
base['options'] = [await child.get_translated_payload(translator) for child in self._children.values()]
base['options'] = [await child.get_translated_payload(tree, translator) for child in self._children.values()]
return base
def to_dict(self) -> Dict[str, Any]:
def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]:
# If this has a parent command then it's part of a subcommand group
# Otherwise, it's just a regular command
option_type = 1 if self.parent is None else AppCommandOptionType.subcommand_group.value
@ -1689,13 +1745,15 @@ class Group:
'name': self.name,
'description': self.description,
'type': option_type,
'options': [child.to_dict() for child in self._children.values()],
'options': [child.to_dict(tree) for child in self._children.values()],
}
if self.parent is None:
base['nsfw'] = self.nsfw
base['dm_permission'] = not self.guild_only
base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value
base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts)
base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs)
return base
@ -1784,7 +1842,7 @@ class Group:
if len(params) != 2:
raise TypeError('The error handler must have 2 parameters.')
self.on_error = coro
self.on_error = coro # type: ignore
return coro
async def interaction_check(self, interaction: Interaction, /) -> bool:
@ -2314,6 +2372,12 @@ def guilds(*guild_ids: Union[Snowflake, int]) -> Callable[[T], T]:
with the :meth:`CommandTree.command` or :meth:`CommandTree.context_menu` decorator
then this must go below that decorator.
.. note ::
Due to a Discord limitation, this decorator cannot be used in conjunction with
contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types
(e.g. :func:`.app_commands.allowed_installs`).
Example:
.. code-block:: python3
@ -2445,8 +2509,70 @@ def guild_only(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = True
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
f.__discord_app_commands_guild_only__ = True # type: ignore # Runtime attribute assignment
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment
allowed_contexts.guild = True
return f
# Check if called with parentheses or not
if func is None:
# Called with parentheses
return inner
else:
return inner(func)
@overload
def private_channel_only(func: None = ...) -> Callable[[T], T]:
...
@overload
def private_channel_only(func: T) -> T:
...
def private_channel_only(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
"""A decorator that indicates this command can only be used in the context of DMs and group DMs.
This is **not** implemented as a :func:`check`, and is instead verified by Discord server side.
Therefore, there is no error handler called when a command is used within a guild.
This decorator can be called with or without parentheses.
Due to a Discord limitation, this decorator does nothing in subcommands and is ignored.
.. versionadded:: 2.4
Examples
---------
.. code-block:: python3
@app_commands.command()
@app_commands.private_channel_only()
async def my_private_channel_only_command(interaction: discord.Interaction) -> None:
await interaction.response.send_message('I am only available in DMs and GDMs!')
"""
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = False
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment
allowed_contexts.private_channel = True
return f
# Check if called with parentheses or not
@ -2457,7 +2583,245 @@ def guild_only(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
return inner(func)
def default_permissions(**perms: bool) -> Callable[[T], T]:
@overload
def dm_only(func: None = ...) -> Callable[[T], T]:
...
@overload
def dm_only(func: T) -> T:
...
def dm_only(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
"""A decorator that indicates this command can only be used in the context of bot DMs.
This is **not** implemented as a :func:`check`, and is instead verified by Discord server side.
Therefore, there is no error handler called when a command is used within a guild or group DM.
This decorator can be called with or without parentheses.
Due to a Discord limitation, this decorator does nothing in subcommands and is ignored.
Examples
---------
.. code-block:: python3
@app_commands.command()
@app_commands.dm_only()
async def my_dm_only_command(interaction: discord.Interaction) -> None:
await interaction.response.send_message('I am only available in DMs!')
"""
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = False
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment
allowed_contexts.dm_channel = True
return f
# Check if called with parentheses or not
if func is None:
# Called with parentheses
return inner
else:
return inner(func)
def allowed_contexts(guilds: bool = MISSING, dms: bool = MISSING, private_channels: bool = MISSING) -> Callable[[T], T]:
"""A decorator that indicates this command can only be used in certain contexts.
Valid contexts are guilds, DMs and private channels.
This is **not** implemented as a :func:`check`, and is instead verified by Discord server side.
Due to a Discord limitation, this decorator does nothing in subcommands and is ignored.
.. versionadded:: 2.4
Examples
---------
.. code-block:: python3
@app_commands.command()
@app_commands.allowed_contexts(guilds=True, dms=False, private_channels=True)
async def my_command(interaction: discord.Interaction) -> None:
await interaction.response.send_message('I am only available in guilds and private channels!')
"""
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = False
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment
if guilds is not MISSING:
allowed_contexts.guild = guilds
if dms is not MISSING:
allowed_contexts.dm_channel = dms
if private_channels is not MISSING:
allowed_contexts.private_channel = private_channels
return f
return inner
@overload
def guild_install(func: None = ...) -> Callable[[T], T]:
...
@overload
def guild_install(func: T) -> T:
...
def guild_install(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
"""A decorator that indicates this command should be installed in guilds.
This is **not** implemented as a :func:`check`, and is instead verified by Discord server side.
Due to a Discord limitation, this decorator does nothing in subcommands and is ignored.
.. versionadded:: 2.4
Examples
---------
.. code-block:: python3
@app_commands.command()
@app_commands.guild_install()
async def my_guild_install_command(interaction: discord.Interaction) -> None:
await interaction.response.send_message('I am installed in guilds by default!')
"""
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
allowed_installs = f.allowed_installs or AppInstallationType()
f.allowed_installs = allowed_installs
else:
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType()
f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment
allowed_installs.guild = True
return f
# Check if called with parentheses or not
if func is None:
# Called with parentheses
return inner
else:
return inner(func)
@overload
def user_install(func: None = ...) -> Callable[[T], T]:
...
@overload
def user_install(func: T) -> T:
...
def user_install(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
"""A decorator that indicates this command should be installed for users.
This is **not** implemented as a :func:`check`, and is instead verified by Discord server side.
Due to a Discord limitation, this decorator does nothing in subcommands and is ignored.
.. versionadded:: 2.4
Examples
---------
.. code-block:: python3
@app_commands.command()
@app_commands.user_install()
async def my_user_install_command(interaction: discord.Interaction) -> None:
await interaction.response.send_message('I am installed in users by default!')
"""
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
allowed_installs = f.allowed_installs or AppInstallationType()
f.allowed_installs = allowed_installs
else:
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType()
f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment
allowed_installs.user = True
return f
# Check if called with parentheses or not
if func is None:
# Called with parentheses
return inner
else:
return inner(func)
def allowed_installs(
guilds: bool = MISSING,
users: bool = MISSING,
) -> Callable[[T], T]:
"""A decorator that indicates this command should be installed in certain contexts.
Valid contexts are guilds and users.
This is **not** implemented as a :func:`check`, and is instead verified by Discord server side.
Due to a Discord limitation, this decorator does nothing in subcommands and is ignored.
.. versionadded:: 2.4
Examples
---------
.. code-block:: python3
@app_commands.command()
@app_commands.allowed_installs(guilds=False, users=True)
async def my_command(interaction: discord.Interaction) -> None:
await interaction.response.send_message('I am installed in users by default!')
"""
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
allowed_installs = f.allowed_installs or AppInstallationType()
f.allowed_installs = allowed_installs
else:
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType()
f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment
if guilds is not MISSING:
allowed_installs.guild = guilds
if users is not MISSING:
allowed_installs.user = users
return f
return inner
def default_permissions(perms_obj: Optional[Permissions] = None, /, **perms: bool) -> Callable[[T], T]:
r"""A decorator that sets the default permissions needed to execute this command.
When this decorator is used, by default users must have these permissions to execute the command.
@ -2481,8 +2845,12 @@ def default_permissions(**perms: bool) -> Callable[[T], T]:
-----------
\*\*perms: :class:`bool`
Keyword arguments denoting the permissions to set as the default.
perms_obj: :class:`~discord.Permissions`
A permissions object as positional argument. This can be used in combination with ``**perms``.
Example
.. versionadded:: 2.5
Examples
---------
.. code-block:: python3
@ -2491,9 +2859,21 @@ def default_permissions(**perms: bool) -> Callable[[T], T]:
@app_commands.default_permissions(manage_messages=True)
async def test(interaction: discord.Interaction):
await interaction.response.send_message('You may or may not have manage messages.')
.. code-block:: python3
ADMIN_PERMS = discord.Permissions(administrator=True)
@app_commands.command()
@app_commands.default_permissions(ADMIN_PERMS, manage_messages=True)
async def test(interaction: discord.Interaction):
await interaction.response.send_message('You may or may not have manage messages.')
"""
permissions = Permissions(**perms)
if perms_obj is not None:
permissions = perms_obj | Permissions(**perms)
else:
permissions = Permissions(**perms)
def decorator(func: T) -> T:
if isinstance(func, (Command, Group, ContextMenu)):

64
discord/app_commands/errors.py

@ -27,7 +27,8 @@ from __future__ import annotations
from typing import Any, TYPE_CHECKING, List, Optional, Sequence, Union
from ..enums import AppCommandOptionType, AppCommandType, Locale
from ..errors import DiscordException, HTTPException, _flatten_error_dict
from ..errors import DiscordException, HTTPException, _flatten_error_dict, MissingApplicationID as MissingApplicationID
from ..utils import _human_join
__all__ = (
'AppCommandError',
@ -58,11 +59,6 @@ if TYPE_CHECKING:
CommandTypes = Union[Command[Any, ..., Any], Group, ContextMenu]
APP_ID_NOT_FOUND = (
'Client does not have an application_id set. Either the function was called before on_ready '
'was called or application_id was not passed to the Client constructor.'
)
class AppCommandError(DiscordException):
"""The base exception type for all application command related errors.
@ -242,13 +238,7 @@ class MissingAnyRole(CheckFailure):
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = _human_join([f"'{role}'" for role in missing_roles])
message = f'You are missing at least one of the required roles: {fmt}'
super().__init__(message)
@ -271,11 +261,7 @@ class MissingPermissions(CheckFailure):
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
fmt = _human_join(missing, final='and')
message = f'You are missing {fmt} permission(s) to run this command.'
super().__init__(message, *args)
@ -298,11 +284,7 @@ class BotMissingPermissions(CheckFailure):
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
fmt = _human_join(missing, final='and')
message = f'Bot requires {fmt} permission(s) to run this command.'
super().__init__(message, *args)
@ -435,19 +417,6 @@ class CommandSignatureMismatch(AppCommandError):
super().__init__(msg)
class MissingApplicationID(AppCommandError):
"""An exception raised when the client does not have an application ID set.
An application ID is required for syncing application commands.
This inherits from :exc:`~discord.app_commands.AppCommandError`.
.. versionadded:: 2.0
"""
def __init__(self, message: Optional[str] = None):
super().__init__(message or APP_ID_NOT_FOUND)
def _get_command_error(
index: str,
inner: Any,
@ -498,6 +467,10 @@ def _get_command_error(
if key == 'options':
for index, d in remaining.items():
_get_command_error(index, d, children, messages, indent=indent + 2)
elif key == '_errors':
errors = [x.get('message', '') for x in remaining]
messages.extend(f'{indentation} {message}' for message in errors)
else:
if isinstance(remaining, dict):
try:
@ -506,10 +479,9 @@ def _get_command_error(
errors = _flatten_error_dict(remaining, key=key)
else:
errors = {key: ' '.join(x.get('message', '') for x in inner_errors)}
else:
errors = _flatten_error_dict(remaining, key=key)
messages.extend(f'{indentation} {k}: {v}' for k, v in errors.items())
if isinstance(errors, dict):
messages.extend(f'{indentation} {k}: {v}' for k, v in errors.items())
class CommandSyncFailure(AppCommandError, HTTPException):
@ -530,8 +502,18 @@ class CommandSyncFailure(AppCommandError, HTTPException):
messages = [f'Failed to upload commands to Discord (HTTP status {self.status}, error code {self.code})']
if self._errors:
for index, inner in self._errors.items():
_get_command_error(index, inner, commands, messages)
# Handle case where the errors dict has no actual chain such as APPLICATION_COMMAND_TOO_LARGE
if len(self._errors) == 1 and '_errors' in self._errors:
errors = self._errors['_errors']
if len(errors) == 1:
extra = errors[0].get('message')
if extra:
messages[0] += f': {extra}'
else:
messages.extend(f'Error {e.get("code", "")}: {e.get("message", "")}' for e in errors)
else:
for index, inner in self._errors.items():
_get_command_error(index, inner, commands, messages)
# Equivalent to super().__init__(...) but skips other constructors
self.args = ('\n'.join(messages),)

213
discord/app_commands/installs.py

@ -0,0 +1,213 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar, List, Optional, Sequence
__all__ = (
'AppInstallationType',
'AppCommandContext',
)
if TYPE_CHECKING:
from typing_extensions import Self
from ..types.interactions import InteractionContextType, InteractionInstallationType
class AppInstallationType:
r"""Represents the installation location of an application command.
.. versionadded:: 2.4
Parameters
-----------
guild: Optional[:class:`bool`]
Whether the integration is a guild install.
user: Optional[:class:`bool`]
Whether the integration is a user install.
"""
__slots__ = ('_guild', '_user')
GUILD: ClassVar[int] = 0
USER: ClassVar[int] = 1
def __init__(self, *, guild: Optional[bool] = None, user: Optional[bool] = None):
self._guild: Optional[bool] = guild
self._user: Optional[bool] = user
def __repr__(self):
return f'<AppInstallationType guild={self.guild!r} user={self.user!r}>'
@property
def guild(self) -> bool:
""":class:`bool`: Whether the integration is a guild install."""
return bool(self._guild)
@guild.setter
def guild(self, value: bool) -> None:
self._guild = bool(value)
@property
def user(self) -> bool:
""":class:`bool`: Whether the integration is a user install."""
return bool(self._user)
@user.setter
def user(self, value: bool) -> None:
self._user = bool(value)
def merge(self, other: AppInstallationType) -> AppInstallationType:
# Merging is similar to AllowedMentions where `self` is the base
# and the `other` is the override preference
guild = self._guild if other._guild is None else other._guild
user = self._user if other._user is None else other._user
return AppInstallationType(guild=guild, user=user)
def _is_unset(self) -> bool:
return all(x is None for x in (self._guild, self._user))
def _merge_to_array(self, other: Optional[AppInstallationType]) -> Optional[List[InteractionInstallationType]]:
result = self.merge(other) if other is not None else self
if result._is_unset():
return None
return result.to_array()
@classmethod
def _from_value(cls, value: Sequence[InteractionInstallationType]) -> Self:
self = cls()
for x in value:
if x == cls.GUILD:
self._guild = True
elif x == cls.USER:
self._user = True
return self
def to_array(self) -> List[InteractionInstallationType]:
values = []
if self._guild:
values.append(self.GUILD)
if self._user:
values.append(self.USER)
return values
class AppCommandContext:
r"""Wraps up the Discord :class:`~discord.app_commands.Command` execution context.
.. versionadded:: 2.4
Parameters
-----------
guild: Optional[:class:`bool`]
Whether the context allows usage in a guild.
dm_channel: Optional[:class:`bool`]
Whether the context allows usage in a DM channel.
private_channel: Optional[:class:`bool`]
Whether the context allows usage in a DM or a GDM channel.
"""
GUILD: ClassVar[int] = 0
DM_CHANNEL: ClassVar[int] = 1
PRIVATE_CHANNEL: ClassVar[int] = 2
__slots__ = ('_guild', '_dm_channel', '_private_channel')
def __init__(
self,
*,
guild: Optional[bool] = None,
dm_channel: Optional[bool] = None,
private_channel: Optional[bool] = None,
):
self._guild: Optional[bool] = guild
self._dm_channel: Optional[bool] = dm_channel
self._private_channel: Optional[bool] = private_channel
def __repr__(self) -> str:
return f'<AppCommandContext guild={self.guild!r} dm_channel={self.dm_channel!r} private_channel={self.private_channel!r}>'
@property
def guild(self) -> bool:
""":class:`bool`: Whether the context allows usage in a guild."""
return bool(self._guild)
@guild.setter
def guild(self, value: bool) -> None:
self._guild = bool(value)
@property
def dm_channel(self) -> bool:
""":class:`bool`: Whether the context allows usage in a DM channel."""
return bool(self._dm_channel)
@dm_channel.setter
def dm_channel(self, value: bool) -> None:
self._dm_channel = bool(value)
@property
def private_channel(self) -> bool:
""":class:`bool`: Whether the context allows usage in a DM or a GDM channel."""
return bool(self._private_channel)
@private_channel.setter
def private_channel(self, value: bool) -> None:
self._private_channel = bool(value)
def merge(self, other: AppCommandContext) -> AppCommandContext:
guild = self._guild if other._guild is None else other._guild
dm_channel = self._dm_channel if other._dm_channel is None else other._dm_channel
private_channel = self._private_channel if other._private_channel is None else other._private_channel
return AppCommandContext(guild=guild, dm_channel=dm_channel, private_channel=private_channel)
def _is_unset(self) -> bool:
return all(x is None for x in (self._guild, self._dm_channel, self._private_channel))
def _merge_to_array(self, other: Optional[AppCommandContext]) -> Optional[List[InteractionContextType]]:
result = self.merge(other) if other is not None else self
if result._is_unset():
return None
return result.to_array()
@classmethod
def _from_value(cls, value: Sequence[InteractionContextType]) -> Self:
self = cls()
for x in value:
if x == cls.GUILD:
self._guild = True
elif x == cls.DM_CHANNEL:
self._dm_channel = True
elif x == cls.PRIVATE_CHANNEL:
self._private_channel = True
return self
def to_array(self) -> List[InteractionContextType]:
values = []
if self._guild:
values.append(self.GUILD)
if self._dm_channel:
values.append(self.DM_CHANNEL)
if self._private_channel:
values.append(self.PRIVATE_CHANNEL)
return values

110
discord/app_commands/models.py

@ -26,9 +26,17 @@ from __future__ import annotations
from datetime import datetime
from .errors import MissingApplicationID
from ..flags import AppCommandContext, AppInstallationType, ChannelFlags
from .translator import TranslationContextLocation, TranslationContext, locale_str, Translator
from ..permissions import Permissions
from ..enums import AppCommandOptionType, AppCommandType, AppCommandPermissionType, ChannelType, Locale, try_enum
from ..enums import (
AppCommandOptionType,
AppCommandType,
AppCommandPermissionType,
ChannelType,
Locale,
try_enum,
)
from ..mixins import Hashable
from ..utils import _get_as_snowflake, parse_time, snowflake_time, MISSING
from ..object import Object
@ -160,6 +168,14 @@ class AppCommand(Hashable):
The default member permissions that can run this command.
dm_permission: :class:`bool`
A boolean that indicates whether this command can be run in direct messages.
allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`]
The contexts that this command is allowed to be used in. Overrides the ``dm_permission`` attribute.
.. versionadded:: 2.4
allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`]
The installation contexts that this command is allowed to be installed in.
.. versionadded:: 2.4
guild_id: Optional[:class:`int`]
The ID of the guild this command is registered in. A value of ``None``
denotes that it is a global command.
@ -179,6 +195,8 @@ class AppCommand(Hashable):
'options',
'default_member_permissions',
'dm_permission',
'allowed_contexts',
'allowed_installs',
'nsfw',
'_state',
)
@ -210,6 +228,19 @@ class AppCommand(Hashable):
dm_permission = True
self.dm_permission: bool = dm_permission
allowed_contexts = data.get('contexts')
if allowed_contexts is None:
self.allowed_contexts: Optional[AppCommandContext] = None
else:
self.allowed_contexts = AppCommandContext._from_value(allowed_contexts)
allowed_installs = data.get('integration_types')
if allowed_installs is None:
self.allowed_installs: Optional[AppInstallationType] = None
else:
self.allowed_installs = AppInstallationType._from_value(allowed_installs)
self.nsfw: bool = data.get('nsfw', False)
self.name_localizations: Dict[Locale, str] = _to_locale_dict(data.get('name_localizations') or {})
self.description_localizations: Dict[Locale, str] = _to_locale_dict(data.get('description_localizations') or {})
@ -223,6 +254,8 @@ class AppCommand(Hashable):
'description': self.description,
'name_localizations': {str(k): v for k, v in self.name_localizations.items()},
'description_localizations': {str(k): v for k, v in self.description_localizations.items()},
'contexts': self.allowed_contexts.to_array() if self.allowed_contexts is not None else None,
'integration_types': self.allowed_installs.to_array() if self.allowed_installs is not None else None,
'options': [opt.to_dict() for opt in self.options],
} # type: ignore # Type checker does not understand this literal.
@ -542,6 +575,35 @@ class AppCommandChannel(Hashable):
the application command in that channel.
guild_id: :class:`int`
The guild ID this channel belongs to.
category_id: Optional[:class:`int`]
The category channel ID this channel belongs to, if applicable.
.. versionadded:: 2.6
topic: Optional[:class:`str`]
The channel's topic. ``None`` if it doesn't exist.
.. versionadded:: 2.6
position: :class:`int`
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0.
.. versionadded:: 2.6
last_message_id: Optional[:class:`int`]
The last message ID of the message sent to this channel. It may
*not* point to an existing or valid message.
.. versionadded:: 2.6
slowmode_delay: :class:`int`
The number of seconds a member must wait between sending messages
in this channel. A value of ``0`` denotes that it is disabled.
Bots and users with :attr:`~discord.Permissions.manage_channels` or
:attr:`~discord.Permissions.manage_messages` bypass slowmode.
.. versionadded:: 2.6
nsfw: :class:`bool`
If the channel is marked as "not safe for work" or "age restricted".
.. versionadded:: 2.6
"""
__slots__ = (
@ -550,6 +612,14 @@ class AppCommandChannel(Hashable):
'name',
'permissions',
'guild_id',
'topic',
'nsfw',
'position',
'category_id',
'slowmode_delay',
'last_message_id',
'_last_pin',
'_flags',
'_state',
)
@ -566,6 +636,14 @@ class AppCommandChannel(Hashable):
self.type: ChannelType = try_enum(ChannelType, data['type'])
self.name: str = data['name']
self.permissions: Permissions = Permissions(int(data['permissions']))
self.topic: Optional[str] = data.get('topic')
self.position: int = data.get('position') or 0
self.nsfw: bool = data.get('nsfw') or False
self.category_id: Optional[int] = _get_as_snowflake(data, 'parent_id')
self.slowmode_delay: int = data.get('rate_limit_per_user') or 0
self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id')
self._last_pin: Optional[datetime] = parse_time(data.get('last_pin_timestamp'))
self._flags: int = data.get('flags', 0)
def __str__(self) -> str:
return self.name
@ -578,6 +656,28 @@ class AppCommandChannel(Hashable):
"""Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found."""
return self._state._get_guild(self.guild_id)
@property
def flags(self) -> ChannelFlags:
""":class:`~discord.ChannelFlags`: The flags associated with this channel object.
.. versionadded:: 2.6
"""
return ChannelFlags._from_value(self._flags)
def is_nsfw(self) -> bool:
""":class:`bool`: Checks if the channel is NSFW.
.. versionadded:: 2.6
"""
return self.nsfw
def is_news(self) -> bool:
""":class:`bool`: Checks if the channel is a news channel.
.. versionadded:: 2.6
"""
return self.type == ChannelType.news
def resolve(self) -> Optional[GuildChannel]:
"""Resolves the application command channel to the appropriate channel
from cache if found.
@ -673,7 +773,7 @@ class AppCommandThread(Hashable):
archiver_id: Optional[:class:`int`]
The user's ID that archived this thread.
auto_archive_duration: :class:`int`
The duration in minutes until the thread is automatically archived due to inactivity.
The duration in minutes until the thread is automatically hidden from the channel list.
Usually a value of 60, 1440, 4320 and 10080.
archive_timestamp: :class:`datetime.datetime`
An aware timestamp of when the thread's archived status was last updated in UTC.
@ -1030,6 +1130,9 @@ class AppCommandPermissions:
self.target: Union[Object, User, Member, Role, AllChannels, GuildChannel] = _object
def __repr__(self) -> str:
return f'<AppCommandPermissions id={self.id} type={self.type!r} guild={self.guild!r} permission={self.permission}>'
def to_dict(self) -> ApplicationCommandPermissions:
return {
'id': self.target.id,
@ -1073,6 +1176,9 @@ class GuildAppCommandPermissions:
AppCommandPermissions(data=value, guild=guild, state=self._state) for value in data['permissions']
]
def __repr__(self) -> str:
return f'<GuildAppCommandPermissions id={self.id!r} guild_id={self.guild_id!r} permissions={self.permissions!r}>'
def to_dict(self) -> Dict[str, Any]:
return {'permissions': [p.to_dict() for p in self.permissions]}

4
discord/app_commands/namespace.py

@ -179,7 +179,7 @@ class Namespace:
state = interaction._state
members = resolved.get('members', {})
guild_id = interaction.guild_id
guild = state._get_or_create_unavailable_guild(guild_id) if guild_id is not None else None
guild = interaction.guild
type = AppCommandOptionType.user.value
for (user_id, user_data) in resolved.get('users', {}).items():
try:
@ -220,7 +220,6 @@ class Namespace:
}
)
guild = state._get_guild(guild_id)
for (message_id, message_data) in resolved.get('messages', {}).items():
channel_id = int(message_data['channel_id'])
if guild is None:
@ -232,6 +231,7 @@ class Namespace:
# Type checker doesn't understand this due to failure to narrow
message = Message(state=state, channel=channel, data=message_data) # type: ignore
message.guild = guild
key = ResolveKey(id=message_id, type=-1)
completed[key] = message

59
discord/app_commands/transformers.py

@ -34,6 +34,7 @@ from typing import (
ClassVar,
Coroutine,
Dict,
Generic,
List,
Literal,
Optional,
@ -51,11 +52,12 @@ from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel,
from ..abc import GuildChannel
from ..threads import Thread
from ..enums import Enum as InternalEnum, AppCommandOptionType, ChannelType, Locale
from ..utils import MISSING, maybe_coroutine
from ..utils import MISSING, maybe_coroutine, _human_join
from ..user import User
from ..role import Role
from ..member import Member
from ..message import Attachment
from .._types import ClientT
__all__ = (
'Transformer',
@ -177,8 +179,7 @@ class CommandParameter:
return choice
try:
# ParamSpec doesn't understand that transform is a callable since it's unbound
return await maybe_coroutine(self._annotation.transform, interaction, value) # type: ignore
return await maybe_coroutine(self._annotation.transform, interaction, value)
except AppCommandError:
raise
except Exception as e:
@ -192,7 +193,7 @@ class CommandParameter:
return self.name if self._rename is MISSING else str(self._rename)
class Transformer:
class Transformer(Generic[ClientT]):
"""The base class that allows a type annotation in an application command parameter
to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one
from this type.
@ -234,7 +235,7 @@ class Transformer:
pass
def __or__(self, rhs: Any) -> Any:
return Union[self, rhs] # type: ignore
return Union[self, rhs]
@property
def type(self) -> AppCommandOptionType:
@ -305,7 +306,7 @@ class Transformer:
else:
return name
async def transform(self, interaction: Interaction, value: Any, /) -> Any:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
"""|maybecoro|
Transforms the converted option value into another value.
@ -325,7 +326,7 @@ class Transformer:
raise NotImplementedError('Derived classes need to implement this.')
async def autocomplete(
self, interaction: Interaction, value: Union[int, float, str], /
self, interaction: Interaction[ClientT], value: Union[int, float, str], /
) -> List[Choice[Union[int, float, str]]]:
"""|coro|
@ -353,7 +354,7 @@ class Transformer:
raise NotImplementedError('Derived classes can implement this.')
class IdentityTransformer(Transformer):
class IdentityTransformer(Transformer[ClientT]):
def __init__(self, type: AppCommandOptionType) -> None:
self._type = type
@ -361,7 +362,7 @@ class IdentityTransformer(Transformer):
def type(self) -> AppCommandOptionType:
return self._type
async def transform(self, interaction: Interaction, value: Any, /) -> Any:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
return value
@ -490,7 +491,7 @@ class EnumNameTransformer(Transformer):
return self._enum[value]
class InlineTransformer(Transformer):
class InlineTransformer(Transformer[ClientT]):
def __init__(self, annotation: Any) -> None:
super().__init__()
self.annotation: Any = annotation
@ -503,7 +504,7 @@ class InlineTransformer(Transformer):
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.string
async def transform(self, interaction: Interaction, value: Any, /) -> Any:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any:
return await self.annotation.transform(interaction, value)
@ -526,7 +527,7 @@ else:
.. versionadded:: 2.0
"""
def __class_getitem__(cls, items) -> _TransformMetadata:
def __class_getitem__(cls, items) -> Transformer:
if not isinstance(items, tuple):
raise TypeError(f'expected tuple for arguments, received {items.__class__.__name__} instead')
@ -571,7 +572,7 @@ else:
await interaction.response.send_message(f'Your value is {value}', ephemeral=True)
"""
def __class_getitem__(cls, obj) -> _TransformMetadata:
def __class_getitem__(cls, obj) -> RangeTransformer:
if not isinstance(obj, tuple):
raise TypeError(f'expected tuple for arguments, received {obj.__class__.__name__} instead')
@ -612,25 +613,25 @@ else:
return transformer
class MemberTransformer(Transformer):
class MemberTransformer(Transformer[ClientT]):
@property
def type(self) -> AppCommandOptionType:
return AppCommandOptionType.user
async def transform(self, interaction: Interaction, value: Any, /) -> Member:
async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Member:
if not isinstance(value, Member):
raise TransformerError(value, self.type, self)
return value
class BaseChannelTransformer(Transformer):
class BaseChannelTransformer(Transformer[ClientT]):
def __init__(self, *channel_types: Type[Any]) -> None:
super().__init__()
if len(channel_types) == 1:
display_name = channel_types[0].__name__
types = CHANNEL_TO_TYPES[channel_types[0]]
else:
display_name = '{}, and {}'.format(', '.join(t.__name__ for t in channel_types[:-1]), channel_types[-1].__name__)
display_name = _human_join([t.__name__ for t in channel_types])
types = []
for t in channel_types:
@ -639,7 +640,7 @@ class BaseChannelTransformer(Transformer):
except KeyError:
raise TypeError('Union type of channels must be entirely made up of channels') from None
self._types: Tuple[Type[Any]] = channel_types
self._types: Tuple[Type[Any], ...] = channel_types
self._channel_types: List[ChannelType] = types
self._display_name = display_name
@ -655,22 +656,22 @@ class BaseChannelTransformer(Transformer):
def channel_types(self) -> List[ChannelType]:
return self._channel_types
async def transform(self, interaction: Interaction, value: Any, /):
async def transform(self, interaction: Interaction[ClientT], value: Any, /):
resolved = value.resolve()
if resolved is None or not isinstance(resolved, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return resolved
class RawChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any, /):
class RawChannelTransformer(BaseChannelTransformer[ClientT]):
async def transform(self, interaction: Interaction[ClientT], value: Any, /):
if not isinstance(value, self._types):
raise TransformerError(value, AppCommandOptionType.channel, self)
return value
class UnionChannelTransformer(BaseChannelTransformer):
async def transform(self, interaction: Interaction, value: Any, /):
class UnionChannelTransformer(BaseChannelTransformer[ClientT]):
async def transform(self, interaction: Interaction[ClientT], value: Any, /):
if isinstance(value, self._types):
return value
@ -688,6 +689,7 @@ CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
ChannelType.news,
ChannelType.category,
ChannelType.forum,
ChannelType.media,
],
GuildChannel: [
ChannelType.stage_voice,
@ -696,6 +698,7 @@ CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
ChannelType.news,
ChannelType.category,
ChannelType.forum,
ChannelType.media,
],
AppCommandThread: [ChannelType.news_thread, ChannelType.private_thread, ChannelType.public_thread],
Thread: [ChannelType.news_thread, ChannelType.private_thread, ChannelType.public_thread],
@ -703,7 +706,7 @@ CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = {
VoiceChannel: [ChannelType.voice],
TextChannel: [ChannelType.text, ChannelType.news],
CategoryChannel: [ChannelType.category],
ForumChannel: [ChannelType.forum],
ForumChannel: [ChannelType.forum, ChannelType.media],
}
BUILT_IN_TRANSFORMERS: Dict[Any, Transformer] = {
@ -750,7 +753,7 @@ def get_supported_annotation(
try:
return (_mapping[annotation], MISSING, True)
except KeyError:
except (KeyError, TypeError):
pass
if isinstance(annotation, Transformer):
@ -781,11 +784,11 @@ def get_supported_annotation(
# Check if there's an origin
origin = getattr(annotation, '__origin__', None)
if origin is Literal:
args = annotation.__args__ # type: ignore
args = annotation.__args__
return (LiteralTransformer(args), MISSING, True)
if origin is Choice:
arg = annotation.__args__[0] # type: ignore
arg = annotation.__args__[0]
return (ChoiceTransformer(arg), MISSING, True)
if origin is not Union:
@ -793,7 +796,7 @@ def get_supported_annotation(
raise TypeError(f'unsupported type annotation {annotation!r}')
default = MISSING
args = annotation.__args__ # type: ignore
args = annotation.__args__
if args[-1] is _none:
if len(args) == 2:
underlying = args[0]

2
discord/app_commands/translator.py

@ -109,7 +109,7 @@ class TranslationContext(Generic[_L, _D]):
def __init__(self, location: Literal[TranslationContextLocation.other], data: Any) -> None:
...
def __init__(self, location: _L, data: _D) -> None:
def __init__(self, location: _L, data: _D) -> None: # type: ignore # pyright doesn't like the overloads
self.location: _L = location
self.data: _D = data

78
discord/app_commands/tree.py

@ -58,6 +58,7 @@ from .errors import (
CommandSyncFailure,
MissingApplicationID,
)
from .installs import AppCommandContext, AppInstallationType
from .translator import Translator, locale_str
from ..errors import ClientException, HTTPException
from ..enums import AppCommandType, InteractionType
@ -72,7 +73,7 @@ if TYPE_CHECKING:
from .commands import ContextMenuCallback, CommandCallback, P, T
ErrorFunc = Callable[
[Interaction, AppCommandError],
[Interaction[ClientT], AppCommandError],
Coroutine[Any, Any, Any],
]
@ -121,9 +122,26 @@ class CommandTree(Generic[ClientT]):
to find the guild-specific ``/ping`` command it will fall back to the global ``/ping`` command.
This has the potential to raise more :exc:`~discord.app_commands.CommandSignatureMismatch` errors
than usual. Defaults to ``True``.
allowed_contexts: :class:`~discord.app_commands.AppCommandContext`
The default allowed contexts that applies to all commands in this tree.
Note that you can override this on a per command basis.
.. versionadded:: 2.4
allowed_installs: :class:`~discord.app_commands.AppInstallationType`
The default allowed install locations that apply to all commands in this tree.
Note that you can override this on a per command basis.
.. versionadded:: 2.4
"""
def __init__(self, client: ClientT, *, fallback_to_global: bool = True):
def __init__(
self,
client: ClientT,
*,
fallback_to_global: bool = True,
allowed_contexts: AppCommandContext = MISSING,
allowed_installs: AppInstallationType = MISSING,
):
self.client: ClientT = client
self._http = client.http
self._state = client._connection
@ -133,6 +151,8 @@ class CommandTree(Generic[ClientT]):
self._state._command_tree = self
self.fallback_to_global: bool = fallback_to_global
self.allowed_contexts = AppCommandContext() if allowed_contexts is MISSING else allowed_contexts
self.allowed_installs = AppInstallationType() if allowed_installs is MISSING else allowed_installs
self._guild_commands: Dict[int, Dict[str, Union[Command, Group]]] = {}
self._global_commands: Dict[str, Union[Command, Group]] = {}
# (name, guild_id, command_type): Command
@ -287,10 +307,24 @@ class CommandTree(Generic[ClientT]):
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to add the command to. If not given or ``None`` then it
becomes a global command instead.
.. note ::
Due to a Discord limitation, this keyword argument cannot be used in conjunction with
contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types
(e.g. :func:`.app_commands.allowed_installs`).
guilds: List[:class:`~discord.abc.Snowflake`]
The list of guilds to add the command to. This cannot be mixed
with the ``guild`` parameter. If no guilds are given at all
then it becomes a global command instead.
.. note ::
Due to a Discord limitation, this keyword argument cannot be used in conjunction with
contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types
(e.g. :func:`.app_commands.allowed_installs`).
override: :class:`bool`
Whether to override a command with the same name. If ``False``
an exception is raised. Default is ``False``.
@ -722,7 +756,7 @@ class CommandTree(Generic[ClientT]):
else:
guild_id = None if guild is None else guild.id
value = type.value
for ((_, g, t), command) in self._context_menus.items():
for (_, g, t), command in self._context_menus.items():
if g == guild_id and t == value:
yield command
@ -799,7 +833,7 @@ class CommandTree(Generic[ClientT]):
else:
_log.error('Ignoring exception in command tree', exc_info=error)
def error(self, coro: ErrorFunc) -> ErrorFunc:
def error(self, coro: ErrorFunc[ClientT]) -> ErrorFunc[ClientT]:
"""A decorator that registers a coroutine as a local error handler.
This must match the signature of the :meth:`on_error` callback.
@ -825,7 +859,7 @@ class CommandTree(Generic[ClientT]):
if len(params) != 2:
raise TypeError('error handler must have 2 parameters')
self.on_error = coro
self.on_error = coro # type: ignore
return coro
def command(
@ -857,10 +891,24 @@ class CommandTree(Generic[ClientT]):
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to add the command to. If not given or ``None`` then it
becomes a global command instead.
.. note ::
Due to a Discord limitation, this keyword argument cannot be used in conjunction with
contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types
(e.g. :func:`.app_commands.allowed_installs`).
guilds: List[:class:`~discord.abc.Snowflake`]
The list of guilds to add the command to. This cannot be mixed
with the ``guild`` parameter. If no guilds are given at all
then it becomes a global command instead.
.. note ::
Due to a Discord limitation, this keyword argument cannot be used in conjunction with
contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types
(e.g. :func:`.app_commands.allowed_installs`).
auto_locale_strings: :class:`bool`
If this is set to ``True``, then all translatable strings will implicitly
be wrapped into :class:`locale_str` rather than :class:`str`. This could
@ -940,10 +988,24 @@ class CommandTree(Generic[ClientT]):
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to add the command to. If not given or ``None`` then it
becomes a global command instead.
.. note ::
Due to a Discord limitation, this keyword argument cannot be used in conjunction with
contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types
(e.g. :func:`.app_commands.allowed_installs`).
guilds: List[:class:`~discord.abc.Snowflake`]
The list of guilds to add the command to. This cannot be mixed
with the ``guild`` parameter. If no guilds are given at all
then it becomes a global command instead.
.. note ::
Due to a Discord limitation, this keyword argument cannot be used in conjunction with
contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types
(e.g. :func:`.app_commands.allowed_installs`).
auto_locale_strings: :class:`bool`
If this is set to ``True``, then all translatable strings will implicitly
be wrapped into :class:`locale_str` rather than :class:`str`. This could
@ -1058,9 +1120,9 @@ class CommandTree(Generic[ClientT]):
translator = self.translator
if translator:
payload = [await command.get_translated_payload(translator) for command in commands]
payload = [await command.get_translated_payload(self, translator) for command in commands]
else:
payload = [command.to_dict() for command in commands]
payload = [command.to_dict(self) for command in commands]
try:
if guild is None:
@ -1240,7 +1302,7 @@ class CommandTree(Generic[ClientT]):
await command._invoke_autocomplete(interaction, focused, namespace)
except Exception:
# Suppress exception since it can't be handled anyway.
pass
_log.exception('Ignoring exception in autocomplete for %r', command.qualified_name)
return

327
discord/appinfo.py

@ -24,20 +24,24 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import List, TYPE_CHECKING, Optional
from typing import List, TYPE_CHECKING, Literal, Optional
from . import utils
from .asset import Asset
from .flags import ApplicationFlags
from .permissions import Permissions
from .utils import MISSING
if TYPE_CHECKING:
from typing import Dict, Any
from .guild import Guild
from .types.appinfo import (
AppInfo as AppInfoPayload,
PartialAppInfo as PartialAppInfoPayload,
Team as TeamPayload,
InstallParams as InstallParamsPayload,
AppIntegrationTypeConfig as AppIntegrationTypeConfigPayload,
)
from .user import User
from .state import ConnectionState
@ -46,6 +50,7 @@ __all__ = (
'AppInfo',
'PartialAppInfo',
'AppInstallParams',
'IntegrationTypeConfig',
)
@ -131,6 +136,23 @@ class AppInfo:
a verification method in the guild's role verification configuration.
.. versionadded:: 2.2
interactions_endpoint_url: Optional[:class:`str`]
The interactions endpoint url of the application to receive interactions over this endpoint rather than
over the gateway, if configured.
.. versionadded:: 2.4
redirect_uris: List[:class:`str`]
A list of authentication redirect URIs.
.. versionadded:: 2.4
approximate_guild_count: :class:`int`
The approximate count of the guilds the bot was added to.
.. versionadded:: 2.4
approximate_user_install_count: Optional[:class:`int`]
The approximate count of the user-level installations the bot has.
.. versionadded:: 2.5
"""
__slots__ = (
@ -156,6 +178,11 @@ class AppInfo:
'custom_install_url',
'install_params',
'role_connections_verification_url',
'interactions_endpoint_url',
'redirect_uris',
'approximate_guild_count',
'approximate_user_install_count',
'_integration_types_config',
)
def __init__(self, state: ConnectionState, data: AppInfoPayload):
@ -166,7 +193,7 @@ class AppInfo:
self.name: str = data['name']
self.description: str = data['description']
self._icon: Optional[str] = data['icon']
self.rpc_origins: List[str] = data['rpc_origins']
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins')
self.bot_public: bool = data['bot_public']
self.bot_require_code_grant: bool = data['bot_require_code_grant']
self.owner: User = state.create_user(data['owner'])
@ -190,6 +217,13 @@ class AppInfo:
params = data.get('install_params')
self.install_params: Optional[AppInstallParams] = AppInstallParams(params) if params else None
self.interactions_endpoint_url: Optional[str] = data.get('interactions_endpoint_url')
self.redirect_uris: List[str] = data.get('redirect_uris', [])
self.approximate_guild_count: int = data.get('approximate_guild_count', 0)
self.approximate_user_install_count: Optional[int] = data.get('approximate_user_install_count')
self._integration_types_config: Dict[Literal['0', '1'], AppIntegrationTypeConfigPayload] = data.get(
'integration_types_config', {}
)
def __repr__(self) -> str:
return (
@ -232,6 +266,236 @@ class AppInfo:
"""
return ApplicationFlags._from_value(self._flags)
@property
def guild_integration_config(self) -> Optional[IntegrationTypeConfig]:
"""Optional[:class:`IntegrationTypeConfig`]: The default settings for the
application's installation context in a guild.
.. versionadded:: 2.5
"""
if not self._integration_types_config:
return None
try:
return IntegrationTypeConfig(self._integration_types_config['0'])
except KeyError:
return None
@property
def user_integration_config(self) -> Optional[IntegrationTypeConfig]:
"""Optional[:class:`IntegrationTypeConfig`]: The default settings for the
application's installation context as a user.
.. versionadded:: 2.5
"""
if not self._integration_types_config:
return None
try:
return IntegrationTypeConfig(self._integration_types_config['1'])
except KeyError:
return None
async def edit(
self,
*,
reason: Optional[str] = MISSING,
custom_install_url: Optional[str] = MISSING,
description: Optional[str] = MISSING,
role_connections_verification_url: Optional[str] = MISSING,
install_params_scopes: Optional[List[str]] = MISSING,
install_params_permissions: Optional[Permissions] = MISSING,
flags: Optional[ApplicationFlags] = MISSING,
icon: Optional[bytes] = MISSING,
cover_image: Optional[bytes] = MISSING,
interactions_endpoint_url: Optional[str] = MISSING,
tags: Optional[List[str]] = MISSING,
guild_install_scopes: Optional[List[str]] = MISSING,
guild_install_permissions: Optional[Permissions] = MISSING,
user_install_scopes: Optional[List[str]] = MISSING,
user_install_permissions: Optional[Permissions] = MISSING,
) -> AppInfo:
r"""|coro|
Edits the application info.
.. versionadded:: 2.4
Parameters
----------
custom_install_url: Optional[:class:`str`]
The new custom authorization URL for the application. Can be ``None`` to remove the URL.
description: Optional[:class:`str`]
The new application description. Can be ``None`` to remove the description.
role_connections_verification_url: Optional[:class:`str`]
The new applications connection verification URL which will render the application
as a verification method in the guilds role verification configuration. Can be ``None`` to remove the URL.
install_params_scopes: Optional[List[:class:`str`]]
The new list of :ddocs:`OAuth2 scopes <topics/oauth2#shared-resources-oauth2-scopes>` of
the :attr:`~install_params`. Can be ``None`` to remove the scopes.
install_params_permissions: Optional[:class:`Permissions`]
The new permissions of the :attr:`~install_params`. Can be ``None`` to remove the permissions.
flags: Optional[:class:`ApplicationFlags`]
The new applications flags. Only limited intent flags (:attr:`~ApplicationFlags.gateway_presence_limited`,
:attr:`~ApplicationFlags.gateway_guild_members_limited`, :attr:`~ApplicationFlags.gateway_message_content_limited`)
can be edited. Can be ``None`` to remove the flags.
.. warning::
Editing the limited intent flags leads to the termination of the bot.
icon: Optional[:class:`bytes`]
The new applications icon as a :term:`py:bytes-like object`. Can be ``None`` to remove the icon.
cover_image: Optional[:class:`bytes`]
The new applications cover image as a :term:`py:bytes-like object` on a store embed.
The cover image is only available if the application is a game sold on Discord.
Can be ``None`` to remove the image.
interactions_endpoint_url: Optional[:class:`str`]
The new interactions endpoint url of the application to receive interactions over this endpoint rather than
over the gateway. Can be ``None`` to remove the URL.
tags: Optional[List[:class:`str`]]
The new list of tags describing the functionality of the application. Can be ``None`` to remove the tags.
guild_install_scopes: Optional[List[:class:`str`]]
The new list of :ddocs:`OAuth2 scopes <topics/oauth2#shared-resources-oauth2-scopes>` of
the default guild installation context. Can be ``None`` to remove the scopes.
.. versionadded: 2.5
guild_install_permissions: Optional[:class:`Permissions`]
The new permissions of the default guild installation context. Can be ``None`` to remove the permissions.
.. versionadded: 2.5
user_install_scopes: Optional[List[:class:`str`]]
The new list of :ddocs:`OAuth2 scopes <topics/oauth2#shared-resources-oauth2-scopes>` of
the default user installation context. Can be ``None`` to remove the scopes.
.. versionadded: 2.5
user_install_permissions: Optional[:class:`Permissions`]
The new permissions of the default user installation context. Can be ``None`` to remove the permissions.
.. versionadded: 2.5
reason: Optional[:class:`str`]
The reason for editing the application. Shows up on the audit log.
Raises
-------
HTTPException
Editing the application failed
ValueError
The image format passed in to ``icon`` or ``cover_image`` is invalid. This is also raised
when ``install_params_scopes`` and ``install_params_permissions`` are incompatible with each other,
or when ``guild_install_scopes`` and ``guild_install_permissions`` are incompatible with each other.
Returns
-------
:class:`AppInfo`
The newly updated application info.
"""
payload: Dict[str, Any] = {}
if custom_install_url is not MISSING:
payload['custom_install_url'] = custom_install_url
if description is not MISSING:
payload['description'] = description
if role_connections_verification_url is not MISSING:
payload['role_connections_verification_url'] = role_connections_verification_url
if install_params_scopes is not MISSING:
install_params: Optional[Dict[str, Any]] = {}
if install_params_scopes is None:
install_params = None
else:
if "bot" not in install_params_scopes and install_params_permissions is not MISSING:
raise ValueError("'bot' must be in install_params_scopes if install_params_permissions is set")
install_params['scopes'] = install_params_scopes
if install_params_permissions is MISSING:
install_params['permissions'] = 0
else:
if install_params_permissions is None:
install_params['permissions'] = 0
else:
install_params['permissions'] = install_params_permissions.value
payload['install_params'] = install_params
else:
if install_params_permissions is not MISSING:
raise ValueError('install_params_scopes must be set if install_params_permissions is set')
if flags is not MISSING:
if flags is None:
payload['flags'] = flags
else:
payload['flags'] = flags.value
if icon is not MISSING:
if icon is None:
payload['icon'] = icon
else:
payload['icon'] = utils._bytes_to_base64_data(icon)
if cover_image is not MISSING:
if cover_image is None:
payload['cover_image'] = cover_image
else:
payload['cover_image'] = utils._bytes_to_base64_data(cover_image)
if interactions_endpoint_url is not MISSING:
payload['interactions_endpoint_url'] = interactions_endpoint_url
if tags is not MISSING:
payload['tags'] = tags
integration_types_config: Dict[str, Any] = {}
if guild_install_scopes is not MISSING or guild_install_permissions is not MISSING:
guild_install_params: Optional[Dict[str, Any]] = {}
if guild_install_scopes in (None, MISSING):
guild_install_scopes = []
if 'bot' not in guild_install_scopes and guild_install_permissions is not MISSING:
raise ValueError("'bot' must be in guild_install_scopes if guild_install_permissions is set")
if guild_install_permissions in (None, MISSING):
guild_install_params['permissions'] = 0
else:
guild_install_params['permissions'] = guild_install_permissions.value
guild_install_params['scopes'] = guild_install_scopes
integration_types_config['0'] = {'oauth2_install_params': guild_install_params or None}
else:
if guild_install_permissions is not MISSING:
raise ValueError('guild_install_scopes must be set if guild_install_permissions is set')
if user_install_scopes is not MISSING or user_install_permissions is not MISSING:
user_install_params: Optional[Dict[str, Any]] = {}
if user_install_scopes in (None, MISSING):
user_install_scopes = []
if 'bot' not in user_install_scopes and user_install_permissions is not MISSING:
raise ValueError("'bot' must be in user_install_scopes if user_install_permissions is set")
if user_install_permissions in (None, MISSING):
user_install_params['permissions'] = 0
else:
user_install_params['permissions'] = user_install_permissions.value
user_install_params['scopes'] = user_install_scopes
integration_types_config['1'] = {'oauth2_install_params': user_install_params or None}
else:
if user_install_permissions is not MISSING:
raise ValueError('user_install_scopes must be set if user_install_permissions is set')
if integration_types_config:
payload['integration_types_config'] = integration_types_config
data = await self._state.http.edit_application_info(reason=reason, payload=payload)
return AppInfo(data=data, state=self._state)
class PartialAppInfo:
"""Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite`
@ -255,6 +519,24 @@ class PartialAppInfo:
The application's terms of service URL, if set.
privacy_policy_url: Optional[:class:`str`]
The application's privacy policy URL, if set.
approximate_guild_count: :class:`int`
The approximate count of the guilds the bot was added to.
.. versionadded:: 2.3
redirect_uris: List[:class:`str`]
A list of authentication redirect URIs.
.. versionadded:: 2.3
interactions_endpoint_url: Optional[:class:`str`]
The interactions endpoint url of the application to receive interactions over this endpoint rather than
over the gateway, if configured.
.. versionadded:: 2.3
role_connections_verification_url: Optional[:class:`str`]
The application's connection verification URL which will render the application as
a verification method in the guild's role verification configuration.
.. versionadded:: 2.3
"""
__slots__ = (
@ -268,6 +550,11 @@ class PartialAppInfo:
'privacy_policy_url',
'_icon',
'_flags',
'_cover_image',
'approximate_guild_count',
'redirect_uris',
'interactions_endpoint_url',
'role_connections_verification_url',
)
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):
@ -276,11 +563,16 @@ class PartialAppInfo:
self.name: str = data['name']
self._icon: Optional[str] = data.get('icon')
self._flags: int = data.get('flags', 0)
self._cover_image: Optional[str] = data.get('cover_image')
self.description: str = data['description']
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins')
self.verify_key: str = data['verify_key']
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
self.approximate_guild_count: int = data.get('approximate_guild_count', 0)
self.redirect_uris: List[str] = data.get('redirect_uris', [])
self.interactions_endpoint_url: Optional[str] = data.get('interactions_endpoint_url')
self.role_connections_verification_url: Optional[str] = data.get('role_connections_verification_url')
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'
@ -292,6 +584,18 @@ class PartialAppInfo:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='app')
@property
def cover_image(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the cover image of the application's default rich presence.
This is only available if the application is a game sold on Discord.
.. versionadded:: 2.3
"""
if self._cover_image is None:
return None
return Asset._from_cover_image(self._state, self.id, self._cover_image)
@property
def flags(self) -> ApplicationFlags:
""":class:`ApplicationFlags`: The application's flags.
@ -320,3 +624,22 @@ class AppInstallParams:
def __init__(self, data: InstallParamsPayload) -> None:
self.scopes: List[str] = data.get('scopes', [])
self.permissions: Permissions = Permissions(int(data['permissions']))
class IntegrationTypeConfig:
"""Represents the default settings for the application's installation context.
.. versionadded:: 2.5
Attributes
----------
oauth2_install_params: Optional[:class:`AppInstallParams`]
The install params for this installation context's default in-app authorization link.
"""
def __init__(self, data: AppIntegrationTypeConfigPayload) -> None:
self.oauth2_install_params: Optional[AppInstallParams] = None
try:
self.oauth2_install_params = AppInstallParams(data['oauth2_install_params']) # type: ignore # EAFP
except KeyError:
pass

26
discord/asset.py

@ -246,6 +246,26 @@ class Asset(AssetMixin):
animated=animated,
)
@classmethod
def _from_guild_banner(cls, state: _State, guild_id: int, member_id: int, banner: str) -> Self:
animated = banner.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
state,
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/banners/{banner}.{format}?size=1024",
key=banner,
animated=animated,
)
@classmethod
def _from_avatar_decoration(cls, state: _State, avatar_decoration: str) -> Self:
return cls(
state,
url=f'{cls.BASE}/avatar-decoration-presets/{avatar_decoration}.png?size=96',
key=avatar_decoration,
animated=True,
)
@classmethod
def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self:
return cls(
@ -420,7 +440,7 @@ class Asset(AssetMixin):
url = url.with_query(url.raw_query_string)
url = str(url)
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
return self.__class__(state=self._state, url=url, key=self._key, animated=self._animated)
def with_size(self, size: int, /) -> Self:
"""Returns a new asset with the specified size.
@ -448,7 +468,7 @@ class Asset(AssetMixin):
raise ValueError('size must be a power of 2 between 16 and 4096')
url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
return self.__class__(state=self._state, url=url, key=self._key, animated=self._animated)
def with_format(self, format: ValidAssetFormatTypes, /) -> Self:
"""Returns a new asset with the specified format.
@ -483,7 +503,7 @@ class Asset(AssetMixin):
url = yarl.URL(self._url)
path, _ = os.path.splitext(url.path)
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
return self.__class__(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self:
"""Returns a new asset with the specified static format.

189
discord/audit_logs.py

@ -33,7 +33,7 @@ from .invite import Invite
from .mixins import Hashable
from .object import Object
from .permissions import PermissionOverwrite, Permissions
from .automod import AutoModTrigger, AutoModRuleAction, AutoModPresets, AutoModRule
from .automod import AutoModTrigger, AutoModRuleAction, AutoModRule
from .role import Role
from .emoji import Emoji
from .partial_emoji import PartialEmoji
@ -61,6 +61,7 @@ if TYPE_CHECKING:
from .types.audit_log import (
AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload,
_AuditLogChange_TriggerMetadata as AuditLogChangeTriggerMetadataPayload,
)
from .types.channel import (
PermissionOverwrite as PermissionOverwritePayload,
@ -71,9 +72,10 @@ if TYPE_CHECKING:
from .types.role import Role as RolePayload
from .types.snowflake import Snowflake
from .types.command import ApplicationCommandPermissions
from .types.automod import AutoModerationTriggerMetadata, AutoModerationAction
from .types.automod import AutoModerationAction
from .user import User
from .app_commands import AppCommand
from .webhook import Webhook
TargetType = Union[
Guild,
@ -89,6 +91,9 @@ if TYPE_CHECKING:
Object,
PartialIntegration,
AutoModRule,
ScheduledEvent,
Webhook,
AppCommand,
None,
]
@ -140,8 +145,8 @@ def _transform_applied_forum_tags(entry: AuditLogEntry, data: List[Snowflake]) -
return [Object(id=tag_id, type=ForumTag) for tag_id in data]
def _transform_overloaded_flags(entry: AuditLogEntry, data: int) -> Union[int, flags.ChannelFlags]:
# The `flags` key is definitely overloaded. Right now it's for channels and threads but
def _transform_overloaded_flags(entry: AuditLogEntry, data: int) -> Union[int, flags.ChannelFlags, flags.InviteFlags]:
# The `flags` key is definitely overloaded. Right now it's for channels, threads and invites but
# I am aware of `member.flags` and `user.flags` existing. However, this does not impact audit logs
# at the moment but better safe than sorry.
channel_audit_log_types = (
@ -152,9 +157,16 @@ def _transform_overloaded_flags(entry: AuditLogEntry, data: int) -> Union[int, f
enums.AuditLogAction.thread_update,
enums.AuditLogAction.thread_delete,
)
invite_audit_log_types = (
enums.AuditLogAction.invite_create,
enums.AuditLogAction.invite_update,
enums.AuditLogAction.invite_delete,
)
if entry.action in channel_audit_log_types:
return flags.ChannelFlags._from_value(data)
elif entry.action in invite_audit_log_types:
return flags.InviteFlags._from_value(data)
return data
@ -226,43 +238,14 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]
return _transform
def _transform_automod_trigger_metadata(
entry: AuditLogEntry, data: AutoModerationTriggerMetadata
) -> Optional[AutoModTrigger]:
if isinstance(entry.target, AutoModRule):
# Trigger type cannot be changed, so type should be the same before and after updates.
# Avoids checking which keys are in data to guess trigger type
# or returning None if data is empty.
try:
return AutoModTrigger.from_data(type=entry.target.trigger.type.value, data=data)
except Exception:
pass
# Try to infer trigger type from available keys in data
if 'presets' in data:
return AutoModTrigger(
type=enums.AutoModRuleTriggerType.keyword_preset,
presets=AutoModPresets._from_value(data['presets']), # type: ignore
allow_list=data.get('allow_list'),
)
elif 'keyword_filter' in data:
return AutoModTrigger(
type=enums.AutoModRuleTriggerType.keyword,
keyword_filter=data['keyword_filter'], # type: ignore
allow_list=data.get('allow_list'),
regex_patterns=data.get('regex_patterns'),
)
elif 'mention_total_limit' in data:
return AutoModTrigger(type=enums.AutoModRuleTriggerType.mention_spam, mention_limit=data['mention_total_limit']) # type: ignore
else:
return AutoModTrigger(type=enums.AutoModRuleTriggerType.spam)
def _transform_automod_actions(entry: AuditLogEntry, data: List[AutoModerationAction]) -> List[AutoModRuleAction]:
return [AutoModRuleAction.from_data(action) for action in data]
def _transform_default_emoji(entry: AuditLogEntry, data: str) -> PartialEmoji:
return PartialEmoji(name=data)
E = TypeVar('E', bound=enums.Enum)
@ -362,7 +345,6 @@ class AuditLogChanges:
'image_hash': ('cover_image', _transform_cover_image),
'trigger_type': (None, _enum_transformer(enums.AutoModRuleTriggerType)),
'event_type': (None, _enum_transformer(enums.AutoModRuleEventType)),
'trigger_metadata': ('trigger', _transform_automod_trigger_metadata),
'actions': (None, _transform_automod_actions),
'exempt_channels': (None, _transform_channels_or_threads),
'exempt_roles': (None, _transform_roles),
@ -370,6 +352,8 @@ class AuditLogChanges:
'available_tags': (None, _transform_forum_tags),
'flags': (None, _transform_overloaded_flags),
'default_reaction_emoji': (None, _transform_default_reaction),
'emoji_name': ('emoji', _transform_default_emoji),
'user_id': ('user', _transform_member_id)
}
# fmt: on
@ -408,6 +392,21 @@ class AuditLogChanges:
self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore # new_value is a list of roles in this case
continue
# special case for automod trigger
if attr == 'trigger_metadata':
# given full metadata dict
self._handle_trigger_metadata(entry, elem, data) # type: ignore # should be trigger metadata
continue
elif entry.action is enums.AuditLogAction.automod_rule_update and attr.startswith('$'):
# on update, some trigger attributes are keys and formatted as $(add/remove)_{attribute}
action, _, trigger_attr = attr.partition('_')
# new_value should be a list of added/removed strings for keyword_filter, regex_patterns, or allow_list
if action == '$add':
self._handle_trigger_attr_update(self.before, self.after, entry, trigger_attr, elem['new_value']) # type: ignore
elif action == '$remove':
self._handle_trigger_attr_update(self.after, self.before, entry, trigger_attr, elem['new_value']) # type: ignore
continue
try:
key, transformer = self.TRANSFORMERS[attr]
except (ValueError, KeyError):
@ -484,6 +483,76 @@ class AuditLogChanges:
guild = entry.guild
diff.app_command_permissions.append(AppCommandPermissions(data=data, guild=guild, state=state))
def _handle_trigger_metadata(
self,
entry: AuditLogEntry,
data: AuditLogChangeTriggerMetadataPayload,
full_data: List[AuditLogChangePayload],
):
trigger_value: Optional[int] = None
trigger_type: Optional[enums.AutoModRuleTriggerType] = None
# try to get trigger type from before or after
trigger_type = getattr(self.before, 'trigger_type', getattr(self.after, 'trigger_type', None))
if trigger_type is None:
if isinstance(entry.target, AutoModRule):
# Trigger type cannot be changed, so it should be the same before and after updates.
# Avoids checking which keys are in data to guess trigger type
trigger_value = entry.target.trigger.type.value
else:
# found a trigger type from before or after
trigger_value = trigger_type.value
if trigger_value is None:
# try to find trigger type in the full list of changes
_elem = utils.find(lambda elem: elem['key'] == 'trigger_type', full_data)
if _elem is not None:
trigger_value = _elem.get('old_value', _elem.get('new_value')) # type: ignore # trigger type values should be int
if trigger_value is None:
# try to infer trigger_type from the keys in old or new value
combined = (data.get('old_value') or {}).keys() | (data.get('new_value') or {}).keys()
if not combined:
trigger_value = enums.AutoModRuleTriggerType.spam.value
elif 'presets' in combined:
trigger_value = enums.AutoModRuleTriggerType.keyword_preset.value
elif 'keyword_filter' in combined or 'regex_patterns' in combined:
trigger_value = enums.AutoModRuleTriggerType.keyword.value
elif 'mention_total_limit' in combined or 'mention_raid_protection_enabled' in combined:
trigger_value = enums.AutoModRuleTriggerType.mention_spam.value
else:
# some unknown type
trigger_value = -1
self.before.trigger = AutoModTrigger.from_data(trigger_value, data.get('old_value'))
self.after.trigger = AutoModTrigger.from_data(trigger_value, data.get('new_value'))
def _handle_trigger_attr_update(
self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, attr: str, data: List[str]
):
self._create_trigger(first, entry)
trigger = self._create_trigger(second, entry)
try:
# guard unexpecte non list attributes or non iterable data
getattr(trigger, attr).extend(data)
except (AttributeError, TypeError):
pass
def _create_trigger(self, diff: AuditLogDiff, entry: AuditLogEntry) -> AutoModTrigger:
# check if trigger has already been created
if not hasattr(diff, 'trigger'):
# create a trigger
if isinstance(entry.target, AutoModRule):
# get trigger type from the automod rule
trigger_type = entry.target.trigger.type
else:
# unknown trigger type
trigger_type = enums.try_enum(enums.AutoModRuleTriggerType, -1)
diff.trigger = AutoModTrigger(type=trigger_type)
return diff.trigger
class _AuditLogProxy:
def __init__(self, **kwargs: Any) -> None:
@ -521,7 +590,11 @@ class _AuditLogProxyMessageBulkDelete(_AuditLogProxy):
class _AuditLogProxyAutoModAction(_AuditLogProxy):
automod_rule_name: str
automod_rule_trigger_type: str
channel: Union[abc.GuildChannel, Thread]
channel: Optional[Union[abc.GuildChannel, Thread]]
class _AuditLogProxyMemberKickOrMemberRoleUpdate(_AuditLogProxy):
integration_type: Optional[str]
class AuditLogEntry(Hashable):
@ -580,6 +653,7 @@ class AuditLogEntry(Hashable):
integrations: Mapping[int, PartialIntegration],
app_commands: Mapping[int, AppCommand],
automod_rules: Mapping[int, AutoModRule],
webhooks: Mapping[int, Webhook],
data: AuditLogEntryPayload,
guild: Guild,
):
@ -589,6 +663,7 @@ class AuditLogEntry(Hashable):
self._integrations: Mapping[int, PartialIntegration] = integrations
self._app_commands: Mapping[int, AppCommand] = app_commands
self._automod_rules: Mapping[int, AutoModRule] = automod_rules
self._webhooks: Mapping[int, Webhook] = webhooks
self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None:
@ -608,6 +683,7 @@ class AuditLogEntry(Hashable):
_AuditLogProxyStageInstanceAction,
_AuditLogProxyMessageBulkDelete,
_AuditLogProxyAutoModAction,
_AuditLogProxyMemberKickOrMemberRoleUpdate,
Member, User, None, PartialIntegration,
Role, Object
] = None
@ -632,6 +708,10 @@ class AuditLogEntry(Hashable):
elif self.action is enums.AuditLogAction.message_bulk_delete:
# The bulk message delete action has the number of messages deleted
self.extra = _AuditLogProxyMessageBulkDelete(count=int(extra['count']))
elif self.action in (enums.AuditLogAction.kick, enums.AuditLogAction.member_role_update):
# The member kick action has a dict with some information
integration_type = extra.get('integration_type')
self.extra = _AuditLogProxyMemberKickOrMemberRoleUpdate(integration_type=integration_type)
elif self.action.name.endswith('pin'):
# the pin actions have a dict with some information
channel_id = int(extra['channel_id'])
@ -644,13 +724,19 @@ class AuditLogEntry(Hashable):
or self.action is enums.AuditLogAction.automod_flag_message
or self.action is enums.AuditLogAction.automod_timeout_member
):
channel_id = int(extra['channel_id'])
channel_id = utils._get_as_snowflake(extra, 'channel_id')
channel = None
# May be an empty string instead of None due to a Discord issue
if channel_id:
channel = self.guild.get_channel_or_thread(channel_id) or Object(id=channel_id)
self.extra = _AuditLogProxyAutoModAction(
automod_rule_name=extra['auto_moderation_rule_name'],
automod_rule_trigger_type=enums.try_enum(
enums.AutoModRuleTriggerType, extra['auto_moderation_rule_trigger_type']
),
channel=self.guild.get_channel_or_thread(channel_id) or Object(id=channel_id),
channel=channel,
)
elif self.action.name.startswith('overwrite_'):
@ -760,7 +846,12 @@ class AuditLogEntry(Hashable):
def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]:
return self.guild.get_channel(target_id) or Object(id=target_id)
def _convert_target_user(self, target_id: int) -> Union[Member, User, Object]:
def _convert_target_user(self, target_id: Optional[int]) -> Optional[Union[Member, User, Object]]:
# For some reason the member_disconnect and member_move action types
# do not have a non-null target_id so safeguard against that
if target_id is None:
return None
return self._get_member(target_id) or Object(id=target_id, type=Member)
def _convert_target_role(self, target_id: int) -> Union[Role, Object]:
@ -790,7 +881,13 @@ class AuditLogEntry(Hashable):
def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]:
return self._state.get_emoji(target_id) or Object(id=target_id, type=Emoji)
def _convert_target_message(self, target_id: int) -> Union[Member, User, Object]:
def _convert_target_message(self, target_id: Optional[int]) -> Optional[Union[Member, User, Object]]:
# The message_pin and message_unpin action types do not have a
# non-null target_id so safeguard against that
if target_id is None:
return None
return self._get_member(target_id) or Object(id=target_id, type=Member)
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
@ -840,3 +937,9 @@ class AuditLogEntry(Hashable):
def _convert_target_auto_moderation(self, target_id: int) -> Union[AutoModRule, Object]:
return self._automod_rules.get(target_id) or Object(target_id, type=AutoModRule)
def _convert_target_webhook(self, target_id: int) -> Union[Webhook, Object]:
# circular import
from .webhook import Webhook
return self._webhooks.get(target_id) or Object(target_id, type=Webhook)

156
discord/automod.py

@ -25,8 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional, List, Sequence, Set, Union, Sequence
from typing import TYPE_CHECKING, Any, Dict, Optional, List, Set, Union, Sequence, overload, Literal
from .enums import AutoModRuleTriggerType, AutoModRuleActionType, AutoModRuleEventType, try_enum
from .flags import AutoModPresets
@ -59,6 +58,9 @@ __all__ = (
class AutoModRuleAction:
"""Represents an auto moderation's rule action.
.. note::
Only one of ``channel_id``, ``duration``, or ``custom_message`` can be used.
.. versionadded:: 2.0
Attributes
@ -73,40 +75,114 @@ class AutoModRuleAction:
The duration of the timeout to apply, if any.
Has a maximum of 28 days.
Passing this sets :attr:`type` to :attr:`~AutoModRuleActionType.timeout`.
custom_message: Optional[:class:`str`]
A custom message which will be shown to a user when their message is blocked.
Passing this sets :attr:`type` to :attr:`~AutoModRuleActionType.block_message`.
.. versionadded:: 2.2
"""
__slots__ = ('type', 'channel_id', 'duration')
__slots__ = ('type', 'channel_id', 'duration', 'custom_message')
@overload
def __init__(self, *, channel_id: int = ...) -> None:
...
@overload
def __init__(self, *, type: Literal[AutoModRuleActionType.send_alert_message], channel_id: int = ...) -> None:
...
@overload
def __init__(self, *, duration: datetime.timedelta = ...) -> None:
...
@overload
def __init__(self, *, type: Literal[AutoModRuleActionType.timeout], duration: datetime.timedelta = ...) -> None:
...
@overload
def __init__(self, *, custom_message: str = ...) -> None:
...
@overload
def __init__(self, *, type: Literal[AutoModRuleActionType.block_message]) -> None:
...
@overload
def __init__(self, *, type: Literal[AutoModRuleActionType.block_message], custom_message: Optional[str] = ...) -> None:
...
@overload
def __init__(
self,
*,
type: Optional[AutoModRuleActionType] = ...,
channel_id: Optional[int] = ...,
duration: Optional[datetime.timedelta] = ...,
custom_message: Optional[str] = ...,
) -> None:
...
def __init__(
self,
*,
type: Optional[AutoModRuleActionType] = None,
channel_id: Optional[int] = None,
duration: Optional[datetime.timedelta] = None,
custom_message: Optional[str] = None,
) -> None:
if sum(v is None for v in (channel_id, duration, custom_message)) < 2:
raise ValueError('Only one of channel_id, duration, or custom_message can be passed.')
def __init__(self, *, channel_id: Optional[int] = None, duration: Optional[datetime.timedelta] = None) -> None:
self.channel_id: Optional[int] = channel_id
self.duration: Optional[datetime.timedelta] = duration
if channel_id and duration:
raise ValueError('Please provide only one of ``channel`` or ``duration``')
self.type: AutoModRuleActionType
self.channel_id: Optional[int] = None
self.duration: Optional[datetime.timedelta] = None
self.custom_message: Optional[str] = None
if channel_id:
if type is not None:
self.type = type
elif channel_id is not None:
self.type = AutoModRuleActionType.send_alert_message
elif duration:
elif duration is not None:
self.type = AutoModRuleActionType.timeout
else:
self.type = AutoModRuleActionType.block_message
if self.type is AutoModRuleActionType.send_alert_message:
if channel_id is None:
raise ValueError('channel_id cannot be None if type is send_alert_message')
self.channel_id = channel_id
if self.type is AutoModRuleActionType.timeout:
if duration is None:
raise ValueError('duration cannot be None set if type is timeout')
self.duration = duration
if self.type is AutoModRuleActionType.block_message:
self.custom_message = custom_message
def __repr__(self) -> str:
return f'<AutoModRuleAction type={self.type.value} channel={self.channel_id} duration={self.duration}>'
@classmethod
def from_data(cls, data: AutoModerationActionPayload) -> Self:
type_ = try_enum(AutoModRuleActionType, data['type'])
if data['type'] == AutoModRuleActionType.timeout.value:
duration_seconds = data['metadata']['duration_seconds']
return cls(duration=datetime.timedelta(seconds=duration_seconds))
elif data['type'] == AutoModRuleActionType.send_alert_message.value:
channel_id = int(data['metadata']['channel_id'])
return cls(channel_id=channel_id)
return cls()
elif data['type'] == AutoModRuleActionType.block_message.value:
custom_message = data.get('metadata', {}).get('custom_message')
return cls(type=AutoModRuleActionType.block_message, custom_message=custom_message)
return cls(type=AutoModRuleActionType.block_member_interactions)
def to_dict(self) -> Dict[str, Any]:
ret = {'type': self.type.value, 'metadata': {}}
if self.type is AutoModRuleActionType.timeout:
if self.type is AutoModRuleActionType.block_message and self.custom_message is not None:
ret['metadata'] = {'custom_message': self.custom_message}
elif self.type is AutoModRuleActionType.timeout:
ret['metadata'] = {'duration_seconds': int(self.duration.total_seconds())} # type: ignore # duration cannot be None here
elif self.type is AutoModRuleActionType.send_alert_message:
ret['metadata'] = {'channel_id': str(self.channel_id)}
@ -128,7 +204,11 @@ class AutoModTrigger:
+-----------------------------------------------+------------------------------------------------+
| :attr:`AutoModRuleTriggerType.keyword_preset` | :attr:`presets`\, :attr:`allow_list` |
+-----------------------------------------------+------------------------------------------------+
| :attr:`AutoModRuleTriggerType.mention_spam` | :attr:`mention_limit` |
| :attr:`AutoModRuleTriggerType.mention_spam` | :attr:`mention_limit`, |
| | :attr:`mention_raid_protection` |
+-----------------------------------------------+------------------------------------------------+
| :attr:`AutoModRuleTriggerType.member_profile` | :attr:`keyword_filter`, :attr:`regex_patterns`,|
| | :attr:`allow_list` |
+-----------------------------------------------+------------------------------------------------+
.. versionadded:: 2.0
@ -138,14 +218,14 @@ class AutoModTrigger:
type: :class:`AutoModRuleTriggerType`
The type of trigger.
keyword_filter: List[:class:`str`]
The list of strings that will trigger the keyword filter. Maximum of 1000.
Keywords can only be up to 30 characters in length.
The list of strings that will trigger the filter.
Maximum of 1000. Keywords can only be up to 60 characters in length.
This could be combined with :attr:`regex_patterns`.
regex_patterns: List[:class:`str`]
The regex pattern that will trigger the filter. The syntax is based off of
`Rust's regex syntax <https://docs.rs/regex/latest/regex/#syntax>`_.
Maximum of 10. Regex strings can only be up to 250 characters in length.
Maximum of 10. Regex strings can only be up to 260 characters in length.
This could be combined with :attr:`keyword_filter` and/or :attr:`allow_list`
@ -153,10 +233,15 @@ class AutoModTrigger:
presets: :class:`AutoModPresets`
The presets used with the preset keyword filter.
allow_list: List[:class:`str`]
The list of words that are exempt from the commonly flagged words.
The list of words that are exempt from the commonly flagged words. Maximum of 100.
Keywords can only be up to 60 characters in length.
mention_limit: :class:`int`
The total number of user and role mentions a message can contain.
Has a maximum of 50.
mention_raid_protection: :class:`bool`
Whether mention raid protection is enabled or not.
.. versionadded:: 2.4
"""
__slots__ = (
@ -166,6 +251,7 @@ class AutoModTrigger:
'allow_list',
'mention_limit',
'regex_patterns',
'mention_raid_protection',
)
def __init__(
@ -177,9 +263,13 @@ class AutoModTrigger:
allow_list: Optional[List[str]] = None,
mention_limit: Optional[int] = None,
regex_patterns: Optional[List[str]] = None,
mention_raid_protection: Optional[bool] = None,
) -> None:
if type is None and sum(arg is not None for arg in (keyword_filter or regex_patterns, presets, mention_limit)) > 1:
raise ValueError('Please pass only one of keyword_filter, regex_patterns, presets, or mention_limit.')
unique_args = (keyword_filter or regex_patterns, presets, mention_limit or mention_raid_protection)
if type is None and sum(arg is not None for arg in unique_args) > 1:
raise ValueError(
'Please pass only one of keyword_filter/regex_patterns, presets, or mention_limit/mention_raid_protection.'
)
if type is not None:
self.type = type
@ -187,17 +277,18 @@ class AutoModTrigger:
self.type = AutoModRuleTriggerType.keyword
elif presets is not None:
self.type = AutoModRuleTriggerType.keyword_preset
elif mention_limit is not None:
elif mention_limit is not None or mention_raid_protection is not None:
self.type = AutoModRuleTriggerType.mention_spam
else:
raise ValueError(
'Please pass the trigger type explicitly if not using keyword_filter, presets, or mention_limit.'
'Please pass the trigger type explicitly if not using keyword_filter, regex_patterns, presets, mention_limit, or mention_raid_protection.'
)
self.keyword_filter: List[str] = keyword_filter if keyword_filter is not None else []
self.presets: AutoModPresets = presets if presets is not None else AutoModPresets()
self.allow_list: List[str] = allow_list if allow_list is not None else []
self.mention_limit: int = mention_limit if mention_limit is not None else 0
self.mention_raid_protection: bool = mention_raid_protection if mention_raid_protection is not None else False
self.regex_patterns: List[str] = regex_patterns if regex_patterns is not None else []
def __repr__(self) -> str:
@ -213,7 +304,7 @@ class AutoModTrigger:
type_ = try_enum(AutoModRuleTriggerType, type)
if data is None:
return cls(type=type_)
elif type_ is AutoModRuleTriggerType.keyword:
elif type_ in (AutoModRuleTriggerType.keyword, AutoModRuleTriggerType.member_profile):
return cls(
type=type_,
keyword_filter=data.get('keyword_filter'),
@ -225,12 +316,16 @@ class AutoModTrigger:
type=type_, presets=AutoModPresets._from_value(data.get('presets', [])), allow_list=data.get('allow_list')
)
elif type_ is AutoModRuleTriggerType.mention_spam:
return cls(type=type_, mention_limit=data.get('mention_total_limit'))
return cls(
type=type_,
mention_limit=data.get('mention_total_limit'),
mention_raid_protection=data.get('mention_raid_protection_enabled'),
)
else:
return cls(type=type_)
def to_metadata_dict(self) -> Optional[Dict[str, Any]]:
if self.type is AutoModRuleTriggerType.keyword:
if self.type in (AutoModRuleTriggerType.keyword, AutoModRuleTriggerType.member_profile):
return {
'keyword_filter': self.keyword_filter,
'regex_patterns': self.regex_patterns,
@ -239,7 +334,10 @@ class AutoModTrigger:
elif self.type is AutoModRuleTriggerType.keyword_preset:
return {'presets': self.presets.to_array(), 'allow_list': self.allow_list}
elif self.type is AutoModRuleTriggerType.mention_spam:
return {'mention_total_limit': self.mention_limit}
return {
'mention_total_limit': self.mention_limit,
'mention_raid_protection_enabled': self.mention_raid_protection,
}
class AutoModRule:
@ -265,6 +363,8 @@ class AutoModRule:
The IDs of the roles that are exempt from the rule.
exempt_channel_ids: Set[:class:`int`]
The IDs of the channels that are exempt from the rule.
event_type: :class:`AutoModRuleEventType`
The type of event that will trigger the the rule.
"""
__slots__ = (
@ -418,7 +518,7 @@ class AutoModRule:
payload['name'] = name
if event_type is not MISSING:
payload['event_type'] = event_type
payload['event_type'] = event_type.value
if trigger is not MISSING:
trigger_metadata = trigger.to_metadata_dict()
@ -441,7 +541,7 @@ class AutoModRule:
**payload,
)
return AutoModRule(data=data, guild=self.guild, state=self._state)
return self.__class__(data=data, guild=self.guild, state=self._state)
async def delete(self, *, reason: str = MISSING) -> None:
"""|coro|

596
discord/channel.py

@ -47,7 +47,16 @@ import datetime
import discord.abc
from .scheduled_event import ScheduledEvent
from .permissions import PermissionOverwrite, Permissions
from .enums import ChannelType, ForumLayoutType, PrivacyLevel, try_enum, VideoQualityMode, EntityType
from .enums import (
ChannelType,
ForumLayoutType,
ForumOrderType,
PrivacyLevel,
try_enum,
VideoQualityMode,
EntityType,
VoiceChannelEffectAnimationType,
)
from .mixins import Hashable
from . import utils
from .utils import MISSING
@ -56,8 +65,10 @@ from .errors import ClientException
from .stage_instance import StageInstance
from .threads import Thread
from .partial_emoji import _EmojiTag, PartialEmoji
from .flags import ChannelFlags
from .flags import ChannelFlags, MessageFlags
from .http import handle_message_parameters
from .object import Object
from .soundboard import BaseSoundboardSound, SoundboardDefaultSound
__all__ = (
'TextChannel',
@ -69,6 +80,8 @@ __all__ = (
'ForumChannel',
'GroupChannel',
'PartialMessageable',
'VoiceChannelEffect',
'VoiceChannelSoundEffect',
)
if TYPE_CHECKING:
@ -76,7 +89,6 @@ if TYPE_CHECKING:
from .types.threads import ThreadArchiveDuration
from .role import Role
from .object import Object
from .member import Member, VoiceState
from .abc import Snowflake, SnowflakeTime
from .embeds import Embed
@ -98,9 +110,13 @@ if TYPE_CHECKING:
CategoryChannel as CategoryChannelPayload,
GroupDMChannel as GroupChannelPayload,
ForumChannel as ForumChannelPayload,
MediaChannel as MediaChannelPayload,
ForumTag as ForumTagPayload,
VoiceChannelEffect as VoiceChannelEffectPayload,
)
from .types.snowflake import SnowflakeList
from .types.soundboard import BaseSoundboardSound as BaseSoundboardSoundPayload
from .soundboard import SoundboardSound
OverwriteKeyT = TypeVar('OverwriteKeyT', Role, BaseUser, Object, Union[Role, Member, Object])
@ -110,6 +126,121 @@ class ThreadWithMessage(NamedTuple):
message: Message
class VoiceChannelEffectAnimation(NamedTuple):
id: int
type: VoiceChannelEffectAnimationType
class VoiceChannelSoundEffect(BaseSoundboardSound):
"""Represents a Discord voice channel sound effect.
.. versionadded:: 2.5
.. container:: operations
.. describe:: x == y
Checks if two sound effects are equal.
.. describe:: x != y
Checks if two sound effects are not equal.
.. describe:: hash(x)
Returns the sound effect's hash.
Attributes
------------
id: :class:`int`
The ID of the sound.
volume: :class:`float`
The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%).
"""
__slots__ = ('_state',)
def __init__(self, *, state: ConnectionState, id: int, volume: float):
data: BaseSoundboardSoundPayload = {
'sound_id': id,
'volume': volume,
}
super().__init__(state=state, data=data)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} id={self.id} volume={self.volume}>"
@property
def created_at(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: Returns the snowflake's creation time in UTC.
Returns ``None`` if it's a default sound."""
if self.is_default():
return None
else:
return utils.snowflake_time(self.id)
def is_default(self) -> bool:
""":class:`bool`: Whether it's a default sound or not."""
# if it's smaller than the Discord Epoch it cannot be a snowflake
return self.id < utils.DISCORD_EPOCH
class VoiceChannelEffect:
"""Represents a Discord voice channel effect.
.. versionadded:: 2.5
Attributes
------------
channel: :class:`VoiceChannel`
The channel in which the effect is sent.
user: Optional[:class:`Member`]
The user who sent the effect. ``None`` if not found in cache.
animation: Optional[:class:`VoiceChannelEffectAnimation`]
The animation the effect has. Returns ``None`` if the effect has no animation.
emoji: Optional[:class:`PartialEmoji`]
The emoji of the effect.
sound: Optional[:class:`VoiceChannelSoundEffect`]
The sound of the effect. Returns ``None`` if it's an emoji effect.
"""
__slots__ = ('channel', 'user', 'animation', 'emoji', 'sound')
def __init__(self, *, state: ConnectionState, data: VoiceChannelEffectPayload, guild: Guild):
self.channel: VoiceChannel = guild.get_channel(int(data['channel_id'])) # type: ignore # will always be a VoiceChannel
self.user: Optional[Member] = guild.get_member(int(data['user_id']))
self.animation: Optional[VoiceChannelEffectAnimation] = None
animation_id = data.get('animation_id')
if animation_id is not None:
animation_type = try_enum(VoiceChannelEffectAnimationType, data['animation_type']) # type: ignore # cannot be None here
self.animation = VoiceChannelEffectAnimation(id=animation_id, type=animation_type)
emoji = data.get('emoji')
self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None
self.sound: Optional[VoiceChannelSoundEffect] = None
sound_id: Optional[int] = utils._get_as_snowflake(data, 'sound_id')
if sound_id is not None:
sound_volume = data.get('sound_volume') or 0.0
self.sound = VoiceChannelSoundEffect(state=state, id=sound_id, volume=sound_volume)
def __repr__(self) -> str:
attrs = [
('channel', self.channel),
('user', self.user),
('animation', self.animation),
('emoji', self.emoji),
('sound', self.sound),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f"<{self.__class__.__name__} {inner}>"
def is_sound(self) -> bool:
""":class:`bool`: Whether the effect is a sound or not."""
return self.sound is not None
class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""Represents a Discord guild text channel.
@ -160,6 +291,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
The default auto archive duration in minutes for threads created in this channel.
.. versionadded:: 2.0
default_thread_slowmode_delay: :class:`int`
The default slowmode delay in seconds for threads created in this channel.
.. versionadded:: 2.3
"""
__slots__ = (
@ -176,6 +311,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
'_type',
'last_message_id',
'default_auto_archive_duration',
'default_thread_slowmode_delay',
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[TextChannelPayload, NewsChannelPayload]):
@ -206,6 +342,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
# Does this need coercion into `int`? No idea yet.
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440)
self.default_thread_slowmode_delay: int = data.get('default_thread_rate_limit_per_user', 0)
self._type: Literal[0, 5] = data.get('type', self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data)
@ -301,6 +438,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
category: Optional[CategoryChannel] = ...,
slowmode_delay: int = ...,
default_auto_archive_duration: ThreadArchiveDuration = ...,
default_thread_slowmode_delay: int = ...,
type: ChannelType = ...,
overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ...,
) -> TextChannel:
@ -359,7 +497,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
.. versionadded:: 2.0
default_thread_slowmode_delay: :class:`int`
The new default slowmode delay in seconds for threads created in this channel.
.. versionadded:: 2.3
Raises
------
ValueError
@ -384,9 +525,26 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
async def clone(
self,
*,
name: Optional[str] = None,
category: Optional[CategoryChannel] = None,
reason: Optional[str] = None,
) -> TextChannel:
base: Dict[Any, Any] = {
'topic': self.topic,
'nsfw': self.nsfw,
'default_auto_archive_duration': self.default_auto_archive_duration,
'default_thread_rate_limit_per_user': self.default_thread_slowmode_delay,
}
if not self.is_news():
base['rate_limit_per_user'] = self.slowmode_delay
return await self._clone_impl(
{'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason
base,
name=name,
category=category,
reason=reason,
)
async def delete_messages(self, messages: Iterable[Snowflake], *, reason: Optional[str] = None) -> None:
@ -727,7 +885,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
If ``None`` is passed then a private thread is created.
Defaults to ``None``.
auto_archive_duration: :class:`int`
The duration in minutes before a thread is automatically archived for inactivity.
The duration in minutes before a thread is automatically hidden from the channel list.
If not provided, the channel's default auto archive duration is used.
Must be one of ``60``, ``1440``, ``4320``, or ``10080``, if provided.
@ -766,7 +924,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.id,
name=name,
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
type=type.value,
type=type.value, # type: ignore # we're assuming that the user is passing a valid variant
reason=reason,
invitable=invitable,
rate_limit_per_user=slowmode_delay,
@ -878,7 +1036,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
before_timestamp = update_before(threads[-1])
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
class VocalGuildChannel(discord.abc.Messageable, discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
__slots__ = (
'name',
'id',
@ -901,6 +1059,9 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
self.id: int = int(data['id'])
self._update(guild, data)
async def _get_channel(self) -> Self:
return self
def _get_voice_client_key(self) -> Tuple[int, str]:
return self.guild.id, 'guild_id'
@ -988,103 +1149,6 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
base.value &= ~denied.value
return base
class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
"""Represents a Discord guild voice channel.
.. container:: operations
.. describe:: x == y
Checks if two channels are equal.
.. describe:: x != y
Checks if two channels are not equal.
.. describe:: hash(x)
Returns the channel's hash.
.. describe:: str(x)
Returns the channel's name.
Attributes
-----------
name: :class:`str`
The channel name.
guild: :class:`Guild`
The guild the channel belongs to.
id: :class:`int`
The channel ID.
nsfw: :class:`bool`
If the channel is marked as "not safe for work" or "age restricted".
.. versionadded:: 2.0
category_id: Optional[:class:`int`]
The category channel ID this channel belongs to, if applicable.
position: :class:`int`
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0.
bitrate: :class:`int`
The channel's preferred audio bitrate in bits per second.
user_limit: :class:`int`
The channel's limit for number of members that can be in a voice channel.
rtc_region: Optional[:class:`str`]
The region for the voice channel's voice communication.
A value of ``None`` indicates automatic voice region detection.
.. versionadded:: 1.7
.. versionchanged:: 2.0
The type of this attribute has changed to :class:`str`.
video_quality_mode: :class:`VideoQualityMode`
The camera video quality for the voice channel's participants.
.. versionadded:: 2.0
last_message_id: Optional[:class:`int`]
The last message ID of the message sent to this channel. It may
*not* point to an existing or valid message.
.. versionadded:: 2.0
slowmode_delay: :class:`int`
The number of seconds a member must wait between sending messages
in this channel. A value of ``0`` denotes that it is disabled.
Bots and users with :attr:`~Permissions.manage_channels` or
:attr:`~Permissions.manage_messages` bypass slowmode.
.. versionadded:: 2.2
"""
__slots__ = ()
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('rtc_region', self.rtc_region),
('position', self.position),
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
('category_id', self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
async def _get_channel(self) -> Self:
return self
@property
def _scheduled_event_entity_type(self) -> Optional[EntityType]:
return EntityType.voice
@property
def type(self) -> Literal[ChannelType.voice]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.voice
@property
def last_message(self) -> Optional[Message]:
"""Retrieves the last message from this channel in cache.
@ -1129,7 +1193,7 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
from .message import PartialMessage
return PartialMessage(channel=self, id=message_id)
return PartialMessage(channel=self, id=message_id) # type: ignore # VocalGuildChannel is an impl detail
async def delete_messages(self, messages: Iterable[Snowflake], *, reason: Optional[str] = None) -> None:
"""|coro|
@ -1333,8 +1397,119 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
return Webhook.from_state(data, state=self._state)
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel:
return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason)
async def clone(
self, *, name: Optional[str] = None, category: Optional[CategoryChannel] = None, reason: Optional[str] = None
) -> Self:
base = {
'bitrate': self.bitrate,
'user_limit': self.user_limit,
'rate_limit_per_user': self.slowmode_delay,
'nsfw': self.nsfw,
'video_quality_mode': self.video_quality_mode.value,
}
if self.rtc_region:
base['rtc_region'] = self.rtc_region
return await self._clone_impl(
base,
name=name,
category=category,
reason=reason,
)
class VoiceChannel(VocalGuildChannel):
"""Represents a Discord guild voice channel.
.. container:: operations
.. describe:: x == y
Checks if two channels are equal.
.. describe:: x != y
Checks if two channels are not equal.
.. describe:: hash(x)
Returns the channel's hash.
.. describe:: str(x)
Returns the channel's name.
Attributes
-----------
name: :class:`str`
The channel name.
guild: :class:`Guild`
The guild the channel belongs to.
id: :class:`int`
The channel ID.
nsfw: :class:`bool`
If the channel is marked as "not safe for work" or "age restricted".
.. versionadded:: 2.0
category_id: Optional[:class:`int`]
The category channel ID this channel belongs to, if applicable.
position: :class:`int`
The position in the channel list. This is a number that starts at 0. e.g. the
top channel is position 0.
bitrate: :class:`int`
The channel's preferred audio bitrate in bits per second.
user_limit: :class:`int`
The channel's limit for number of members that can be in a voice channel.
rtc_region: Optional[:class:`str`]
The region for the voice channel's voice communication.
A value of ``None`` indicates automatic voice region detection.
.. versionadded:: 1.7
.. versionchanged:: 2.0
The type of this attribute has changed to :class:`str`.
video_quality_mode: :class:`VideoQualityMode`
The camera video quality for the voice channel's participants.
.. versionadded:: 2.0
last_message_id: Optional[:class:`int`]
The last message ID of the message sent to this channel. It may
*not* point to an existing or valid message.
.. versionadded:: 2.0
slowmode_delay: :class:`int`
The number of seconds a member must wait between sending messages
in this channel. A value of ``0`` denotes that it is disabled.
Bots and users with :attr:`~Permissions.manage_channels` or
:attr:`~Permissions.manage_messages` bypass slowmode.
.. versionadded:: 2.2
"""
__slots__ = ()
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('rtc_region', self.rtc_region),
('position', self.position),
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
('category_id', self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
@property
def _scheduled_event_entity_type(self) -> Optional[EntityType]:
return EntityType.voice
@property
def type(self) -> Literal[ChannelType.voice]:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.voice
@overload
async def edit(self) -> None:
@ -1358,6 +1533,8 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ...,
rtc_region: Optional[str] = ...,
video_quality_mode: VideoQualityMode = ...,
slowmode_delay: int = ...,
status: Optional[str] = ...,
reason: Optional[str] = ...,
) -> VoiceChannel:
...
@ -1417,6 +1594,11 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
The camera video quality for the voice channel's participants.
.. versionadded:: 2.0
status: Optional[:class:`str`]
The new voice channel status. It can be up to 500 characters.
Can be ``None`` to remove the status.
.. versionadded:: 2.4
Raises
------
@ -1438,6 +1620,35 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel):
# the payload will always be the proper channel payload
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
async def send_sound(self, sound: Union[SoundboardSound, SoundboardDefaultSound], /) -> None:
"""|coro|
Sends a soundboard sound for this channel.
You must have :attr:`~Permissions.speak` and :attr:`~Permissions.use_soundboard` to do this.
Additionally, you must have :attr:`~Permissions.use_external_sounds` if the sound is from
a different guild.
.. versionadded:: 2.5
Parameters
-----------
sound: Union[:class:`SoundboardSound`, :class:`SoundboardDefaultSound`]
The sound to send for this channel.
Raises
-------
Forbidden
You do not have permissions to send a sound for this channel.
HTTPException
Sending the sound failed.
"""
payload = {'sound_id': sound.id}
if not isinstance(sound, SoundboardDefaultSound) and self.guild.id != sound.guild.id:
payload['source_guild_id'] = sound.guild.id
await self._state.http.send_soundboard_sound(self.id, **payload)
class StageChannel(VocalGuildChannel):
"""Represents a Discord guild stage channel.
@ -1492,6 +1703,11 @@ class StageChannel(VocalGuildChannel):
The camera video quality for the stage channel's participants.
.. versionadded:: 2.0
last_message_id: Optional[:class:`int`]
The last message ID of the message sent to this channel. It may
*not* point to an existing or valid message.
.. versionadded:: 2.2
slowmode_delay: :class:`int`
The number of seconds a member must wait between sending messages
in this channel. A value of ``0`` denotes that it is disabled.
@ -1565,10 +1781,6 @@ class StageChannel(VocalGuildChannel):
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.stage_voice
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel:
return await self._clone_impl({}, name=name, reason=reason)
@property
def instance(self) -> Optional[StageInstance]:
"""Optional[:class:`StageInstance`]: The running stage instance of the stage channel.
@ -1578,7 +1790,13 @@ class StageChannel(VocalGuildChannel):
return utils.get(self.guild.stage_instances, channel_id=self.id)
async def create_instance(
self, *, topic: str, privacy_level: PrivacyLevel = MISSING, reason: Optional[str] = None
self,
*,
topic: str,
privacy_level: PrivacyLevel = MISSING,
send_start_notification: bool = False,
scheduled_event: Snowflake = MISSING,
reason: Optional[str] = None,
) -> StageInstance:
"""|coro|
@ -1594,6 +1812,15 @@ class StageChannel(VocalGuildChannel):
The stage instance's topic.
privacy_level: :class:`PrivacyLevel`
The stage instance's privacy level. Defaults to :attr:`PrivacyLevel.guild_only`.
send_start_notification: :class:`bool`
Whether to send a start notification. This sends a push notification to @everyone if ``True``. Defaults to ``False``.
You must have :attr:`~Permissions.mention_everyone` to do this.
.. versionadded:: 2.3
scheduled_event: :class:`~discord.abc.Snowflake`
The guild scheduled event associated with the stage instance.
.. versionadded:: 2.4
reason: :class:`str`
The reason the stage instance was created. Shows up on the audit log.
@ -1620,6 +1847,11 @@ class StageChannel(VocalGuildChannel):
payload['privacy_level'] = privacy_level.value
if scheduled_event is not MISSING:
payload['guild_scheduled_event_id'] = scheduled_event.id
payload['send_start_notification'] = send_start_notification
data = await self._state.http.create_stage_instance(**payload, reason=reason)
return StageInstance(guild=self.guild, state=self._state, data=data)
@ -1659,12 +1891,15 @@ class StageChannel(VocalGuildChannel):
*,
name: str = ...,
nsfw: bool = ...,
bitrate: int = ...,
user_limit: int = ...,
position: int = ...,
sync_permissions: int = ...,
category: Optional[CategoryChannel] = ...,
overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ...,
rtc_region: Optional[str] = ...,
video_quality_mode: VideoQualityMode = ...,
slowmode_delay: int = ...,
reason: Optional[str] = ...,
) -> StageChannel:
...
@ -1693,10 +1928,14 @@ class StageChannel(VocalGuildChannel):
----------
name: :class:`str`
The new channel's name.
bitrate: :class:`int`
The new channel's bitrate.
position: :class:`int`
The new channel's position.
nsfw: :class:`bool`
To mark the channel as NSFW or not.
user_limit: :class:`int`
The new channel's user limit.
sync_permissions: :class:`bool`
Whether to sync permissions with the channel's new or pre-existing
category. Defaults to ``False``.
@ -1819,7 +2058,13 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return self.nsfw
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel:
async def clone(
self,
*,
name: Optional[str] = None,
category: Optional[CategoryChannel] = None,
reason: Optional[str] = None,
) -> CategoryChannel:
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
@overload
@ -1939,6 +2184,16 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
ret.sort(key=lambda c: (c.position, c.id))
return ret
@property
def forums(self) -> List[ForumChannel]:
"""List[:class:`ForumChannel`]: Returns the forum channels that are under this category.
.. versionadded:: 2.4
"""
r = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, ForumChannel)]
r.sort(key=lambda c: (c.position, c.id))
return r
async def create_text_channel(self, name: str, **options: Any) -> TextChannel:
"""|coro|
@ -2147,6 +2402,10 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
Defaults to :attr:`ForumLayoutType.not_set`.
.. versionadded:: 2.2
default_sort_order: Optional[:class:`ForumOrderType`]
The default sort order for posts in this forum channel.
.. versionadded:: 2.3
"""
__slots__ = (
@ -2156,6 +2415,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
'topic',
'_state',
'_flags',
'_type',
'nsfw',
'category_id',
'position',
@ -2166,13 +2426,15 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
'default_thread_slowmode_delay',
'default_reaction_emoji',
'default_layout',
'default_sort_order',
'_available_tags',
'_flags',
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: ForumChannelPayload):
def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[ForumChannelPayload, MediaChannelPayload]):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self._type: Literal[15, 16] = data['type']
self._update(guild, data)
def __repr__(self) -> str:
@ -2186,7 +2448,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
def _update(self, guild: Guild, data: ForumChannelPayload) -> None:
def _update(self, guild: Guild, data: Union[ForumChannelPayload, MediaChannelPayload]) -> None:
self.guild: Guild = guild
self.name: str = data['name']
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
@ -2211,18 +2473,33 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
name=default_reaction_emoji.get('emoji_name') or '',
)
self.default_sort_order: Optional[ForumOrderType] = None
default_sort_order = data.get('default_sort_order')
if default_sort_order is not None:
self.default_sort_order = try_enum(ForumOrderType, default_sort_order)
self._flags: int = data.get('flags', 0)
self._fill_overwrites(data)
@property
def type(self) -> Literal[ChannelType.forum]:
def type(self) -> Literal[ChannelType.forum, ChannelType.media]:
""":class:`ChannelType`: The channel's Discord type."""
if self._type == 16:
return ChannelType.media
return ChannelType.forum
@property
def _sorting_bucket(self) -> int:
return ChannelType.text.value
@property
def members(self) -> List[Member]:
"""List[:class:`Member`]: Returns all members that can see this channel.
.. versionadded:: 2.5
"""
return [m for m in self.guild.members if self.permissions_for(m).read_messages]
@property
def _scheduled_event_entity_type(self) -> Optional[EntityType]:
return None
@ -2304,10 +2581,41 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
""":class:`bool`: Checks if the forum is NSFW."""
return self.nsfw
def is_media(self) -> bool:
""":class:`bool`: Checks if the channel is a media channel.
.. versionadded:: 2.4
"""
return self._type == ChannelType.media.value
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> ForumChannel:
async def clone(
self,
*,
name: Optional[str] = None,
category: Optional[CategoryChannel],
reason: Optional[str] = None,
) -> ForumChannel:
base = {
'topic': self.topic,
'rate_limit_per_user': self.slowmode_delay,
'nsfw': self.nsfw,
'default_auto_archive_duration': self.default_auto_archive_duration,
'available_tags': [tag.to_dict() for tag in self.available_tags],
'default_thread_rate_limit_per_user': self.default_thread_slowmode_delay,
}
if self.default_sort_order:
base['default_sort_order'] = self.default_sort_order.value
if self.default_reaction_emoji:
base['default_reaction_emoji'] = self.default_reaction_emoji._to_forum_tag_payload()
if not self.is_media() and self.default_layout:
base['default_forum_layout'] = self.default_layout.value
return await self._clone_impl(
{'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason
base,
name=name,
category=category,
reason=reason,
)
@overload
@ -2337,6 +2645,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
default_thread_slowmode_delay: int = ...,
default_reaction_emoji: Optional[EmojiInputType] = ...,
default_layout: ForumLayoutType = ...,
default_sort_order: ForumOrderType = ...,
require_tag: bool = ...,
) -> ForumChannel:
...
@ -2395,6 +2704,10 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
The new default layout for posts in this forum.
.. versionadded:: 2.2
default_sort_order: Optional[:class:`ForumOrderType`]
The new default sort order for posts in this forum.
.. versionadded:: 2.3
require_tag: :class:`bool`
Whether to require a tag for threads in this channel or not.
@ -2457,6 +2770,21 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
options['default_forum_layout'] = layout.value
try:
sort_order = options.pop('default_sort_order')
except KeyError:
pass
else:
if sort_order is None:
options['default_sort_order'] = None
else:
if not isinstance(sort_order, ForumOrderType):
raise TypeError(
f'default_sort_order parameter must be a ForumOrderType not {sort_order.__class__.__name__}'
)
options['default_sort_order'] = sort_order.value
payload = await self._edit(options, reason=reason)
if payload is not None:
# the payload will always be the proper channel payload
@ -2548,7 +2876,7 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
name: :class:`str`
The name of the thread.
auto_archive_duration: :class:`int`
The duration in minutes before a thread is automatically archived for inactivity.
The duration in minutes before a thread is automatically hidden from the channel list.
If not provided, the channel's default auto archive duration is used.
Must be one of ``60``, ``1440``, ``4320``, or ``10080``, if provided.
@ -2618,8 +2946,6 @@ class ForumChannel(discord.abc.GuildChannel, Hashable):
raise TypeError(f'view parameter must be View not {view.__class__.__name__}')
if suppress_embeds:
from .message import MessageFlags # circular import
flags = MessageFlags._from_value(4)
else:
flags = MISSING
@ -2818,17 +3144,21 @@ class DMChannel(discord.abc.Messageable, discord.abc.PrivateChannel, Hashable):
The user you are participating with in the direct message channel.
If this channel is received through the gateway, the recipient information
may not be always available.
recipients: List[:class:`User`]
The users you are participating with in the DM channel.
.. versionadded:: 2.4
me: :class:`ClientUser`
The user presenting yourself.
id: :class:`int`
The direct message channel ID.
"""
__slots__ = ('id', 'recipient', 'me', '_state')
__slots__ = ('id', 'recipients', 'me', '_state')
def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload):
self._state: ConnectionState = state
self.recipient: Optional[User] = state.store_user(data['recipients'][0])
self.recipients: List[User] = [state.store_user(u) for u in data.get('recipients', [])]
self.me: ClientUser = me
self.id: int = int(data['id'])
@ -2848,11 +3178,17 @@ class DMChannel(discord.abc.Messageable, discord.abc.PrivateChannel, Hashable):
self = cls.__new__(cls)
self._state = state
self.id = channel_id
self.recipient = None
self.recipients = []
# state.user won't be None here
self.me = state.user # type: ignore
return self
@property
def recipient(self) -> Optional[User]:
if self.recipients:
return self.recipients[0]
return None
@property
def type(self) -> Literal[ChannelType.private]:
""":class:`ChannelType`: The channel's Discord type."""
@ -3201,6 +3537,14 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
return Permissions.none()
@property
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the channel.
.. versionadded:: 2.5
"""
return f'<#{self.id}>'
def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
@ -3237,6 +3581,8 @@ def _guild_channel_factory(channel_type: int):
return StageChannel, value
elif value is ChannelType.forum:
return ForumChannel, value
elif value is ChannelType.media:
return ForumChannel, value
else:
return None, value

878
discord/client.py

File diff suppressed because it is too large

103
discord/colour.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import colorsys
@ -104,6 +105,11 @@ class Colour:
Returns the raw colour value.
.. note::
The colour values in the classmethods are mostly provided as-is and can change between
versions should the Discord client's representation of that colour also change.
Attributes
------------
value: :class:`int`
@ -170,7 +176,7 @@ class Colour:
return cls.from_rgb(*(int(x * 255) for x in rgb))
@classmethod
def from_str(cls, value: str) -> Self:
def from_str(cls, value: str) -> Colour:
"""Constructs a :class:`Colour` from a string.
The following formats are accepted:
@ -191,6 +197,9 @@ class Colour:
The string could not be converted into a colour.
"""
if not value:
raise ValueError('unknown colour format given')
if value[0] == '#':
return parse_hex_number(value[1:])
@ -449,20 +458,59 @@ class Colour:
"""
return cls(0x99AAB5)
@classmethod
def ash_theme(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x2E2E34``.
This will appear transparent on Discord's ash theme.
.. colour:: #2E2E34
.. versionadded:: 2.6
"""
return cls(0x2E2E34)
@classmethod
def dark_theme(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x313338``.
"""A factory method that returns a :class:`Colour` with a value of ``0x1A1A1E``.
This will appear transparent on Discord's dark theme.
.. colour:: #313338
.. colour:: #1A1A1E
.. versionadded:: 1.5
.. versionchanged:: 2.2
Updated colour from previous ``0x36393F`` to reflect discord theme changes.
.. versionchanged:: 2.6
Updated colour from previous ``0x313338`` to reflect discord theme changes.
"""
return cls(0x313338)
return cls(0x1A1A1E)
@classmethod
def onyx_theme(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x070709``.
This will appear transparent on Discord's onyx theme.
.. colour:: #070709
.. versionadded:: 2.6
"""
return cls(0x070709)
@classmethod
def light_theme(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0xFBFBFB``.
This will appear transparent on Discord's light theme.
.. colour:: #FBFBFB
.. versionadded:: 2.6
"""
return cls(0xFBFBFB)
@classmethod
def fuchsia(cls) -> Self:
@ -484,25 +532,62 @@ class Colour:
"""
return cls(0xFEE75C)
@classmethod
def ash_embed(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x37373E``.
.. colour:: #37373E
.. versionadded:: 2.6
"""
return cls(0x37373E)
@classmethod
def dark_embed(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x2B2D31``.
"""A factory method that returns a :class:`Colour` with a value of ``0x242429``.
.. colour:: #2B2D31
.. colour:: #242429
.. versionadded:: 2.2
.. versionchanged:: 2.6
Updated colour from previous ``0x2B2D31`` to reflect discord theme changes.
"""
return cls(0x242429)
@classmethod
def onyx_embed(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x131416``.
.. colour:: #131416
.. versionadded:: 2.6
"""
return cls(0x2B2D31)
return cls(0x131416)
@classmethod
def light_embed(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0xEEEFF1``.
"""A factory method that returns a :class:`Colour` with a value of ``0xFFFFFF``.
.. colour:: #EEEFF1
.. versionadded:: 2.2
.. versionchanged:: 2.6
Updated colour from previous ``0xEEEFF1`` to reflect discord theme changes.
"""
return cls(0xFFFFFF)
@classmethod
def pink(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0xEB459F``.
.. colour:: #EB459F
.. versionadded:: 2.3
"""
return cls(0xEEEFF1)
return cls(0xEB459F)
Color = Colour

155
discord/components.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import ClassVar, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, ChannelType
from .enums import try_enum, ComponentType, ButtonStyle, TextStyle, ChannelType, SelectDefaultValueType
from .utils import get_slots, MISSING
from .partial_emoji import PartialEmoji, _EmojiTag
@ -40,8 +40,10 @@ if TYPE_CHECKING:
ActionRow as ActionRowPayload,
TextInput as TextInputPayload,
ActionRowChildComponent as ActionRowChildComponentPayload,
SelectDefaultValues as SelectDefaultValuesPayload,
)
from .emoji import Emoji
from .abc import Snowflake
ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput']
@ -53,6 +55,7 @@ __all__ = (
'SelectMenu',
'SelectOption',
'TextInput',
'SelectDefaultValue',
)
@ -167,6 +170,10 @@ class Button(Component):
The label of the button, if any.
emoji: Optional[:class:`PartialEmoji`]
The emoji of the button, if available.
sku_id: Optional[:class:`int`]
The SKU ID this button sends you to, if available.
.. versionadded:: 2.4
"""
__slots__: Tuple[str, ...] = (
@ -176,6 +183,7 @@ class Button(Component):
'disabled',
'label',
'emoji',
'sku_id',
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
@ -188,10 +196,15 @@ class Button(Component):
self.label: Optional[str] = data.get('label')
self.emoji: Optional[PartialEmoji]
try:
self.emoji = PartialEmoji.from_dict(data['emoji'])
self.emoji = PartialEmoji.from_dict(data['emoji']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.emoji = None
try:
self.sku_id: Optional[int] = int(data['sku_id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.sku_id = None
@property
def type(self) -> Literal[ComponentType.button]:
""":class:`ComponentType`: The type of component."""
@ -204,6 +217,9 @@ class Button(Component):
'disabled': self.disabled,
}
if self.sku_id:
payload['sku_id'] = str(self.sku_id)
if self.label:
payload['label'] = self.label
@ -263,6 +279,7 @@ class SelectMenu(Component):
'options',
'disabled',
'channel_types',
'default_values',
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
@ -276,10 +293,13 @@ class SelectMenu(Component):
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
self.disabled: bool = data.get('disabled', False)
self.channel_types: List[ChannelType] = [try_enum(ChannelType, t) for t in data.get('channel_types', [])]
self.default_values: List[SelectDefaultValue] = [
SelectDefaultValue.from_dict(d) for d in data.get('default_values', [])
]
def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = {
'type': self.type.value,
'type': self.type.value, # type: ignore # we know this is a select menu.
'custom_id': self.custom_id,
'min_values': self.min_values,
'max_values': self.max_values,
@ -291,6 +311,8 @@ class SelectMenu(Component):
payload['options'] = [op.to_dict() for op in self.options]
if self.channel_types:
payload['channel_types'] = [t.value for t in self.channel_types]
if self.default_values:
payload["default_values"] = [v.to_dict() for v in self.default_values]
return payload
@ -309,8 +331,8 @@ class SelectOption:
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not provided when constructed then it defaults to the
label. Can only be up to 100 characters.
If not provided when constructed then it defaults to the label.
Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
@ -323,14 +345,12 @@ class SelectOption:
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not provided when constructed then it defaults to the
label. Can only be up to 100 characters.
label.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
default: :class:`bool`
Whether this option is selected by default.
"""
@ -395,7 +415,7 @@ class SelectOption:
@classmethod
def from_dict(cls, data: SelectOptionPayload) -> SelectOption:
try:
emoji = PartialEmoji.from_dict(data['emoji'])
emoji = PartialEmoji.from_dict(data['emoji']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
emoji = None
@ -422,6 +442,9 @@ class SelectOption:
return payload
def copy(self) -> SelectOption:
return self.__class__.from_dict(self.to_dict())
class TextInput(Component):
"""Represents a text input from the Discord Bot UI Kit.
@ -512,6 +535,116 @@ class TextInput(Component):
return self.value
class SelectDefaultValue:
"""Represents a select menu's default value.
These can be created by users.
.. versionadded:: 2.4
Parameters
-----------
id: :class:`int`
The id of a role, user, or channel.
type: :class:`SelectDefaultValueType`
The type of value that ``id`` represents.
"""
def __init__(
self,
*,
id: int,
type: SelectDefaultValueType,
) -> None:
self.id: int = id
self._type: SelectDefaultValueType = type
@property
def type(self) -> SelectDefaultValueType:
""":class:`SelectDefaultValueType`: The type of value that ``id`` represents."""
return self._type
@type.setter
def type(self, value: SelectDefaultValueType) -> None:
if not isinstance(value, SelectDefaultValueType):
raise TypeError(f'expected SelectDefaultValueType, received {value.__class__.__name__} instead')
self._type = value
def __repr__(self) -> str:
return f'<SelectDefaultValue id={self.id!r} type={self.type!r}>'
@classmethod
def from_dict(cls, data: SelectDefaultValuesPayload) -> SelectDefaultValue:
return cls(
id=data['id'],
type=try_enum(SelectDefaultValueType, data['type']),
)
def to_dict(self) -> SelectDefaultValuesPayload:
return {
'id': self.id,
'type': self._type.value,
}
@classmethod
def from_channel(cls, channel: Snowflake, /) -> Self:
"""Creates a :class:`SelectDefaultValue` with the type set to :attr:`~SelectDefaultValueType.channel`.
Parameters
-----------
channel: :class:`~discord.abc.Snowflake`
The channel to create the default value for.
Returns
--------
:class:`SelectDefaultValue`
The default value created with the channel.
"""
return cls(
id=channel.id,
type=SelectDefaultValueType.channel,
)
@classmethod
def from_role(cls, role: Snowflake, /) -> Self:
"""Creates a :class:`SelectDefaultValue` with the type set to :attr:`~SelectDefaultValueType.role`.
Parameters
-----------
role: :class:`~discord.abc.Snowflake`
The role to create the default value for.
Returns
--------
:class:`SelectDefaultValue`
The default value created with the role.
"""
return cls(
id=role.id,
type=SelectDefaultValueType.role,
)
@classmethod
def from_user(cls, user: Snowflake, /) -> Self:
"""Creates a :class:`SelectDefaultValue` with the type set to :attr:`~SelectDefaultValueType.user`.
Parameters
-----------
user: :class:`~discord.abc.Snowflake`
The user to create the default value for.
Returns
--------
:class:`SelectDefaultValue`
The default value created with the user.
"""
return cls(
id=user.id,
type=SelectDefaultValueType.user,
)
@overload
def _component_factory(data: ActionRowChildComponentPayload) -> Optional[ActionRowChildComponentType]:
...
@ -527,7 +660,7 @@ def _component_factory(data: ComponentPayload) -> Optional[Union[ActionRow, Acti
return ActionRow(data)
elif data['type'] == 2:
return Button(data)
elif data['type'] == 3:
return SelectMenu(data)
elif data['type'] == 4:
return TextInput(data)
elif data['type'] in (3, 5, 6, 7, 8):
return SelectMenu(data)

80
discord/embeds.py

@ -29,6 +29,7 @@ from typing import Any, Dict, List, Mapping, Optional, Protocol, TYPE_CHECKING,
from . import utils
from .colour import Colour
from .flags import AttachmentFlags, EmbedFlags
# fmt: off
__all__ = (
@ -45,7 +46,7 @@ class EmbedProxy:
return len(self.__dict__)
def __repr__(self) -> str:
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_')))
inner = ', '.join((f'{k}={getattr(self, k)!r}' for k in dir(self) if not k.startswith('_')))
return f'EmbedProxy({inner})'
def __getattr__(self, attr: str) -> None:
@ -55,6 +56,22 @@ class EmbedProxy:
return isinstance(other, EmbedProxy) and self.__dict__ == other.__dict__
class EmbedMediaProxy(EmbedProxy):
def __init__(self, layer: Dict[str, Any]):
super().__init__(layer)
self._flags = self.__dict__.pop('flags', 0)
def __bool__(self) -> bool:
# This is a nasty check to see if we only have the `_flags` attribute which is created regardless in init.
# Had we had any of the other items, like image/video data this would be >1 and therefor
# would not be "empty".
return len(self.__dict__) > 1
@property
def flags(self) -> AttachmentFlags:
return AttachmentFlags._from_value(self._flags or 0)
if TYPE_CHECKING:
from typing_extensions import Self
@ -76,11 +93,7 @@ if TYPE_CHECKING:
proxy_url: Optional[str]
height: Optional[int]
width: Optional[int]
class _EmbedVideoProxy(Protocol):
url: Optional[str]
height: Optional[int]
width: Optional[int]
flags: AttachmentFlags
class _EmbedProviderProxy(Protocol):
name: Optional[str]
@ -131,7 +144,7 @@ class Embed:
The type of embed. Usually "rich".
This can be set during initialisation.
Possible strings for embed types can be found on discord's
:ddocs:`api docs <resources/channel#embed-object-embed-types>`
:ddocs:`api docs <resources/message#embed-object-embed-types>`
description: Optional[:class:`str`]
The description of the embed.
This can be set during initialisation.
@ -162,6 +175,7 @@ class Embed:
'_author',
'_fields',
'description',
'_flags',
)
def __init__(
@ -181,6 +195,7 @@ class Embed:
self.type: EmbedType = type
self.url: Optional[str] = url
self.description: Optional[str] = description
self._flags: int = 0
if self.title is not None:
self.title = str(self.title)
@ -199,7 +214,7 @@ class Embed:
"""Converts a :class:`dict` to a :class:`Embed` provided it is in the
format that Discord expects it to be in.
You can find out about this format in the :ddocs:`official Discord documentation <resources/channel#embed-object>`.
You can find out about this format in the :ddocs:`official Discord documentation <resources/message#embed-object>`.
Parameters
-----------
@ -215,6 +230,7 @@ class Embed:
self.type = data.get('type', None)
self.description = data.get('description', None)
self.url = data.get('url', None)
self._flags = data.get('flags', 0)
if self.title is not None:
self.title = str(self.title)
@ -305,8 +321,17 @@ class Embed:
and self.image == other.image
and self.provider == other.provider
and self.video == other.video
and self._flags == other._flags
)
@property
def flags(self) -> EmbedFlags:
""":class:`EmbedFlags`: The flags of this embed.
.. versionadded:: 2.5
"""
return EmbedFlags._from_value(self._flags or 0)
@property
def colour(self) -> Optional[Colour]:
return getattr(self, '_colour', None)
@ -395,15 +420,16 @@ class Embed:
Possible attributes you can access are:
- ``url``
- ``proxy_url``
- ``width``
- ``height``
- ``url`` for the image URL.
- ``proxy_url`` for the proxied image URL.
- ``width`` for the image width.
- ``height`` for the image height.
- ``flags`` for the image's attachment flags.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
return EmbedMediaProxy(getattr(self, '_image', {})) # type: ignore
def set_image(self, *, url: Optional[Any]) -> Self:
"""Sets the image for the embed content.
@ -413,8 +439,9 @@ class Embed:
Parameters
-----------
url: :class:`str`
url: Optional[:class:`str`]
The source URL for the image. Only HTTP(S) is supported.
If ``None`` is passed, any existing image is removed.
Inline attachment URLs are also supported, see :ref:`local_image`.
"""
@ -436,15 +463,16 @@ class Embed:
Possible attributes you can access are:
- ``url``
- ``proxy_url``
- ``width``
- ``height``
- ``url`` for the thumbnail URL.
- ``proxy_url`` for the proxied thumbnail URL.
- ``width`` for the thumbnail width.
- ``height`` for the thumbnail height.
- ``flags`` for the thumbnail's attachment flags.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
return EmbedMediaProxy(getattr(self, '_thumbnail', {})) # type: ignore
def set_thumbnail(self, *, url: Optional[Any]) -> Self:
"""Sets the thumbnail for the embed content.
@ -452,13 +480,11 @@ class Embed:
This function returns the class instance to allow for fluent-style
chaining.
.. versionchanged:: 1.4
Passing ``None`` removes the thumbnail.
Parameters
-----------
url: :class:`str`
url: Optional[:class:`str`]
The source URL for the thumbnail. Only HTTP(S) is supported.
If ``None`` is passed, any existing thumbnail is removed.
Inline attachment URLs are also supported, see :ref:`local_image`.
"""
@ -475,19 +501,21 @@ class Embed:
return self
@property
def video(self) -> _EmbedVideoProxy:
def video(self) -> _EmbedMediaProxy:
"""Returns an ``EmbedProxy`` denoting the video contents.
Possible attributes include:
- ``url`` for the video URL.
- ``proxy_url`` for the proxied video URL.
- ``height`` for the video height.
- ``width`` for the video width.
- ``flags`` for the video's attachment flags.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_video', {})) # type: ignore
return EmbedMediaProxy(getattr(self, '_video', {})) # type: ignore
@property
def provider(self) -> _EmbedProviderProxy:
@ -715,7 +743,7 @@ class Embed:
# fmt: off
result = {
key[1:]: getattr(self, key)
for key in self.__slots__
for key in Embed.__slots__
if key[0] == '_' and hasattr(self, key)
}
# fmt: on

49
discord/emoji.py

@ -29,6 +29,8 @@ from .asset import Asset, AssetMixin
from .utils import SnowflakeList, snowflake_time, MISSING
from .partial_emoji import _EmojiTag, PartialEmoji
from .user import User
from .errors import MissingApplicationID
from .object import Object
# fmt: off
__all__ = (
@ -93,6 +95,10 @@ class Emoji(_EmojiTag, AssetMixin):
user: Optional[:class:`User`]
The user that created the emoji. This can only be retrieved using :meth:`Guild.fetch_emoji` and
having :attr:`~Permissions.manage_emojis`.
Or if :meth:`.is_application_owned` is ``True``, this is the team member that uploaded
the emoji, or the bot user if it was uploaded using the API and this can
only be retrieved using :meth:`~discord.Client.fetch_application_emoji` or :meth:`~discord.Client.fetch_application_emojis`.
"""
__slots__: Tuple[str, ...] = (
@ -108,7 +114,7 @@ class Emoji(_EmojiTag, AssetMixin):
'available',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload) -> None:
def __init__(self, *, guild: Snowflake, state: ConnectionState, data: EmojiPayload) -> None:
self.guild_id: int = guild.id
self._state: ConnectionState = state
self._from_data(data)
@ -196,20 +202,32 @@ class Emoji(_EmojiTag, AssetMixin):
Deletes the custom emoji.
You must have :attr:`~Permissions.manage_emojis` to do this.
You must have :attr:`~Permissions.manage_emojis` to do this if
:meth:`.is_application_owned` is ``False``.
Parameters
-----------
reason: Optional[:class:`str`]
The reason for deleting this emoji. Shows up on the audit log.
This does not apply if :meth:`.is_application_owned` is ``True``.
Raises
-------
Forbidden
You are not allowed to delete emojis.
HTTPException
An error occurred deleting the emoji.
MissingApplicationID
The emoji is owned by an application but the application ID is missing.
"""
if self.is_application_owned():
application_id = self._state.application_id
if application_id is None:
raise MissingApplicationID
await self._state.http.delete_application_emoji(application_id, self.id)
return
await self._state.http.delete_custom_emoji(self.guild_id, self.id, reason=reason)
@ -231,15 +249,22 @@ class Emoji(_EmojiTag, AssetMixin):
The new emoji name.
roles: List[:class:`~discord.abc.Snowflake`]
A list of roles that can use this emoji. An empty list can be passed to make it available to everyone.
This does not apply if :meth:`.is_application_owned` is ``True``.
reason: Optional[:class:`str`]
The reason for editing this emoji. Shows up on the audit log.
This does not apply if :meth:`.is_application_owned` is ``True``.
Raises
-------
Forbidden
You are not allowed to edit emojis.
HTTPException
An error occurred editing the emoji.
MissingApplicationID
The emoji is owned by an application but the application ID is missing
Returns
--------
@ -253,5 +278,25 @@ class Emoji(_EmojiTag, AssetMixin):
if roles is not MISSING:
payload['roles'] = [role.id for role in roles]
if self.is_application_owned():
application_id = self._state.application_id
if application_id is None:
raise MissingApplicationID
payload.pop('roles', None)
data = await self._state.http.edit_application_emoji(
application_id,
self.id,
payload=payload,
)
return Emoji(guild=Object(0), data=data, state=self._state)
data = await self._state.http.edit_custom_emoji(self.guild_id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state) # type: ignore # if guild is None, the http request would have failed
def is_application_owned(self) -> bool:
""":class:`bool`: Whether the emoji is owned by an application.
.. versionadded:: 2.5
"""
return self.guild_id == 0

362
discord/enums.py

@ -42,6 +42,7 @@ __all__ = (
'ActivityType',
'NotificationLevel',
'TeamMembershipState',
'TeamMemberRole',
'WebhookType',
'ExpireBehaviour',
'ExpireBehavior',
@ -67,23 +68,29 @@ __all__ = (
'AutoModRuleEventType',
'AutoModRuleActionType',
'ForumLayoutType',
'ForumOrderType',
'SelectDefaultValueType',
'SKUType',
'EntitlementType',
'EntitlementOwnerType',
'PollLayoutType',
'VoiceChannelEffectAnimationType',
'SubscriptionStatus',
'MessageReferenceType',
)
if TYPE_CHECKING:
from typing_extensions import Self
def _create_value_cls(name: str, comparable: bool):
# All the type ignores here are due to the type checker being unable to recognise
# Runtime type creation without exploding.
cls = namedtuple('_EnumValue_' + name, 'name value')
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>' # type: ignore
cls.__str__ = lambda self: f'{name}.{self.name}' # type: ignore
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
cls.__str__ = lambda self: f'{name}.{self.name}'
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value # type: ignore
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value # type: ignore
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value # type: ignore
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value # type: ignore
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls
@ -98,7 +105,14 @@ class EnumMeta(type):
_enum_member_map_: ClassVar[Dict[str, Any]]
_enum_value_map_: ClassVar[Dict[Any, Any]]
def __new__(cls, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any], *, comparable: bool = False) -> Self:
def __new__(
cls,
name: str,
bases: Tuple[type, ...],
attrs: Dict[str, Any],
*,
comparable: bool = False,
) -> EnumMeta:
value_mapping = {}
member_mapping = {}
member_names = []
@ -201,11 +215,18 @@ class ChannelType(Enum):
private_thread = 12
stage_voice = 13
forum = 15
media = 16
def __str__(self) -> str:
return self.name
class MessageReferenceType(Enum):
default = 0
reply = 0
forward = 1
class MessageType(Enum):
default = 0
recipient_add = 1
@ -234,12 +255,18 @@ class MessageType(Enum):
auto_moderation_action = 24
role_subscription_purchase = 25
interaction_premium_upsell = 26
# stage_start = 27
# stage_end = 28
# stage_speaker = 29
# stage_raise_hand = 30
# stage_topic = 31
stage_start = 27
stage_end = 28
stage_speaker = 29
stage_raise_hand = 30
stage_topic = 31
guild_application_premium_subscription = 32
guild_incident_alert_mode_enabled = 36
guild_incident_alert_mode_disabled = 37
guild_incident_report_raid = 38
guild_incident_report_false_alarm = 39
purchase_notification = 44
poll_result = 46
class SpeakingState(Enum):
@ -294,6 +321,7 @@ class DefaultAvatar(Enum):
green = 2
orange = 3
red = 4
pink = 5
def __str__(self) -> str:
return self.name
@ -312,120 +340,130 @@ class AuditLogActionCategory(Enum):
class AuditLogAction(Enum):
# fmt: off
guild_update = 1
channel_create = 10
channel_update = 11
channel_delete = 12
overwrite_create = 13
overwrite_update = 14
overwrite_delete = 15
kick = 20
member_prune = 21
ban = 22
unban = 23
member_update = 24
member_role_update = 25
member_move = 26
member_disconnect = 27
bot_add = 28
role_create = 30
role_update = 31
role_delete = 32
invite_create = 40
invite_update = 41
invite_delete = 42
webhook_create = 50
webhook_update = 51
webhook_delete = 52
emoji_create = 60
emoji_update = 61
emoji_delete = 62
message_delete = 72
message_bulk_delete = 73
message_pin = 74
message_unpin = 75
integration_create = 80
integration_update = 81
integration_delete = 82
stage_instance_create = 83
stage_instance_update = 84
stage_instance_delete = 85
sticker_create = 90
sticker_update = 91
sticker_delete = 92
scheduled_event_create = 100
scheduled_event_update = 101
scheduled_event_delete = 102
thread_create = 110
thread_update = 111
thread_delete = 112
app_command_permission_update = 121
automod_rule_create = 140
automod_rule_update = 141
automod_rule_delete = 142
automod_block_message = 143
automod_flag_message = 144
automod_timeout_member = 145
guild_update = 1
channel_create = 10
channel_update = 11
channel_delete = 12
overwrite_create = 13
overwrite_update = 14
overwrite_delete = 15
kick = 20
member_prune = 21
ban = 22
unban = 23
member_update = 24
member_role_update = 25
member_move = 26
member_disconnect = 27
bot_add = 28
role_create = 30
role_update = 31
role_delete = 32
invite_create = 40
invite_update = 41
invite_delete = 42
webhook_create = 50
webhook_update = 51
webhook_delete = 52
emoji_create = 60
emoji_update = 61
emoji_delete = 62
message_delete = 72
message_bulk_delete = 73
message_pin = 74
message_unpin = 75
integration_create = 80
integration_update = 81
integration_delete = 82
stage_instance_create = 83
stage_instance_update = 84
stage_instance_delete = 85
sticker_create = 90
sticker_update = 91
sticker_delete = 92
scheduled_event_create = 100
scheduled_event_update = 101
scheduled_event_delete = 102
thread_create = 110
thread_update = 111
thread_delete = 112
app_command_permission_update = 121
soundboard_sound_create = 130
soundboard_sound_update = 131
soundboard_sound_delete = 132
automod_rule_create = 140
automod_rule_update = 141
automod_rule_delete = 142
automod_block_message = 143
automod_flag_message = 144
automod_timeout_member = 145
creator_monetization_request_created = 150
creator_monetization_terms_accepted = 151
# fmt: on
@property
def category(self) -> Optional[AuditLogActionCategory]:
# fmt: off
lookup: Dict[AuditLogAction, Optional[AuditLogActionCategory]] = {
AuditLogAction.guild_update: AuditLogActionCategory.update,
AuditLogAction.channel_create: AuditLogActionCategory.create,
AuditLogAction.channel_update: AuditLogActionCategory.update,
AuditLogAction.channel_delete: AuditLogActionCategory.delete,
AuditLogAction.overwrite_create: AuditLogActionCategory.create,
AuditLogAction.overwrite_update: AuditLogActionCategory.update,
AuditLogAction.overwrite_delete: AuditLogActionCategory.delete,
AuditLogAction.kick: None,
AuditLogAction.member_prune: None,
AuditLogAction.ban: None,
AuditLogAction.unban: None,
AuditLogAction.member_update: AuditLogActionCategory.update,
AuditLogAction.member_role_update: AuditLogActionCategory.update,
AuditLogAction.member_move: None,
AuditLogAction.member_disconnect: None,
AuditLogAction.bot_add: None,
AuditLogAction.role_create: AuditLogActionCategory.create,
AuditLogAction.role_update: AuditLogActionCategory.update,
AuditLogAction.role_delete: AuditLogActionCategory.delete,
AuditLogAction.invite_create: AuditLogActionCategory.create,
AuditLogAction.invite_update: AuditLogActionCategory.update,
AuditLogAction.invite_delete: AuditLogActionCategory.delete,
AuditLogAction.webhook_create: AuditLogActionCategory.create,
AuditLogAction.webhook_update: AuditLogActionCategory.update,
AuditLogAction.webhook_delete: AuditLogActionCategory.delete,
AuditLogAction.emoji_create: AuditLogActionCategory.create,
AuditLogAction.emoji_update: AuditLogActionCategory.update,
AuditLogAction.emoji_delete: AuditLogActionCategory.delete,
AuditLogAction.message_delete: AuditLogActionCategory.delete,
AuditLogAction.message_bulk_delete: AuditLogActionCategory.delete,
AuditLogAction.message_pin: None,
AuditLogAction.message_unpin: None,
AuditLogAction.integration_create: AuditLogActionCategory.create,
AuditLogAction.integration_update: AuditLogActionCategory.update,
AuditLogAction.integration_delete: AuditLogActionCategory.delete,
AuditLogAction.stage_instance_create: AuditLogActionCategory.create,
AuditLogAction.stage_instance_update: AuditLogActionCategory.update,
AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete,
AuditLogAction.sticker_create: AuditLogActionCategory.create,
AuditLogAction.sticker_update: AuditLogActionCategory.update,
AuditLogAction.sticker_delete: AuditLogActionCategory.delete,
AuditLogAction.scheduled_event_create: AuditLogActionCategory.create,
AuditLogAction.scheduled_event_update: AuditLogActionCategory.update,
AuditLogAction.scheduled_event_delete: AuditLogActionCategory.delete,
AuditLogAction.thread_create: AuditLogActionCategory.create,
AuditLogAction.thread_delete: AuditLogActionCategory.delete,
AuditLogAction.thread_update: AuditLogActionCategory.update,
AuditLogAction.app_command_permission_update: AuditLogActionCategory.update,
AuditLogAction.automod_rule_create: AuditLogActionCategory.create,
AuditLogAction.automod_rule_update: AuditLogActionCategory.update,
AuditLogAction.automod_rule_delete: AuditLogActionCategory.delete,
AuditLogAction.automod_block_message: None,
AuditLogAction.automod_flag_message: None,
AuditLogAction.automod_timeout_member: None,
AuditLogAction.guild_update: AuditLogActionCategory.update,
AuditLogAction.channel_create: AuditLogActionCategory.create,
AuditLogAction.channel_update: AuditLogActionCategory.update,
AuditLogAction.channel_delete: AuditLogActionCategory.delete,
AuditLogAction.overwrite_create: AuditLogActionCategory.create,
AuditLogAction.overwrite_update: AuditLogActionCategory.update,
AuditLogAction.overwrite_delete: AuditLogActionCategory.delete,
AuditLogAction.kick: None,
AuditLogAction.member_prune: None,
AuditLogAction.ban: None,
AuditLogAction.unban: None,
AuditLogAction.member_update: AuditLogActionCategory.update,
AuditLogAction.member_role_update: AuditLogActionCategory.update,
AuditLogAction.member_move: None,
AuditLogAction.member_disconnect: None,
AuditLogAction.bot_add: None,
AuditLogAction.role_create: AuditLogActionCategory.create,
AuditLogAction.role_update: AuditLogActionCategory.update,
AuditLogAction.role_delete: AuditLogActionCategory.delete,
AuditLogAction.invite_create: AuditLogActionCategory.create,
AuditLogAction.invite_update: AuditLogActionCategory.update,
AuditLogAction.invite_delete: AuditLogActionCategory.delete,
AuditLogAction.webhook_create: AuditLogActionCategory.create,
AuditLogAction.webhook_update: AuditLogActionCategory.update,
AuditLogAction.webhook_delete: AuditLogActionCategory.delete,
AuditLogAction.emoji_create: AuditLogActionCategory.create,
AuditLogAction.emoji_update: AuditLogActionCategory.update,
AuditLogAction.emoji_delete: AuditLogActionCategory.delete,
AuditLogAction.message_delete: AuditLogActionCategory.delete,
AuditLogAction.message_bulk_delete: AuditLogActionCategory.delete,
AuditLogAction.message_pin: None,
AuditLogAction.message_unpin: None,
AuditLogAction.integration_create: AuditLogActionCategory.create,
AuditLogAction.integration_update: AuditLogActionCategory.update,
AuditLogAction.integration_delete: AuditLogActionCategory.delete,
AuditLogAction.stage_instance_create: AuditLogActionCategory.create,
AuditLogAction.stage_instance_update: AuditLogActionCategory.update,
AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete,
AuditLogAction.sticker_create: AuditLogActionCategory.create,
AuditLogAction.sticker_update: AuditLogActionCategory.update,
AuditLogAction.sticker_delete: AuditLogActionCategory.delete,
AuditLogAction.scheduled_event_create: AuditLogActionCategory.create,
AuditLogAction.scheduled_event_update: AuditLogActionCategory.update,
AuditLogAction.scheduled_event_delete: AuditLogActionCategory.delete,
AuditLogAction.thread_create: AuditLogActionCategory.create,
AuditLogAction.thread_delete: AuditLogActionCategory.delete,
AuditLogAction.thread_update: AuditLogActionCategory.update,
AuditLogAction.app_command_permission_update: AuditLogActionCategory.update,
AuditLogAction.automod_rule_create: AuditLogActionCategory.create,
AuditLogAction.automod_rule_update: AuditLogActionCategory.update,
AuditLogAction.automod_rule_delete: AuditLogActionCategory.delete,
AuditLogAction.automod_block_message: None,
AuditLogAction.automod_flag_message: None,
AuditLogAction.automod_timeout_member: None,
AuditLogAction.creator_monetization_request_created: None,
AuditLogAction.creator_monetization_terms_accepted: None,
AuditLogAction.soundboard_sound_create: AuditLogActionCategory.create,
AuditLogAction.soundboard_sound_update: AuditLogActionCategory.update,
AuditLogAction.soundboard_sound_delete: AuditLogActionCategory.delete,
}
# fmt: on
return lookup[self]
@ -465,10 +503,12 @@ class AuditLogAction(Enum):
return 'thread'
elif v < 122:
return 'integration_or_app_command'
elif v < 143:
elif 139 < v < 143:
return 'auto_moderation'
elif v < 146:
return 'user'
elif v < 152:
return 'creator_monetization'
class UserFlags(Enum):
@ -512,6 +552,12 @@ class TeamMembershipState(Enum):
accepted = 2
class TeamMemberRole(Enum):
admin = 'admin'
developer = 'developer'
read_only = 'read_only'
class WebhookType(Enum):
incoming = 1
channel_follower = 2
@ -574,6 +620,8 @@ class InteractionResponseType(Enum):
message_update = 7 # for components
autocomplete_result = 8
modal = 9 # for modals
# premium_required = 10 (deprecated)
launch_activity = 12
class VideoQualityMode(Enum):
@ -605,6 +653,7 @@ class ButtonStyle(Enum):
success = 3
danger = 4
link = 5
premium = 6
# Aliases
blurple = 1
@ -665,6 +714,7 @@ class Locale(Enum):
italian = 'it'
japanese = 'ja'
korean = 'ko'
latin_american_spanish = 'es-419'
lithuanian = 'lt'
norwegian = 'no'
polish = 'pl'
@ -733,16 +783,19 @@ class AutoModRuleTriggerType(Enum):
spam = 3
keyword_preset = 4
mention_spam = 5
member_profile = 6
class AutoModRuleEventType(Enum):
message_send = 1
member_update = 2
class AutoModRuleActionType(Enum):
block_message = 1
send_alert_message = 2
timeout = 3
block_member_interactions = 4
class ForumLayoutType(Enum):
@ -751,9 +804,64 @@ class ForumLayoutType(Enum):
gallery_view = 2
class OnboardingPromptType(Enum):
multiple_choice = 0
dropdown = 1
class ForumOrderType(Enum):
latest_activity = 0
creation_date = 1
class SelectDefaultValueType(Enum):
user = 'user'
role = 'role'
channel = 'channel'
class SKUType(Enum):
durable = 2
consumable = 3
subscription = 5
subscription_group = 6
class EntitlementType(Enum):
purchase = 1
premium_subscription = 2
developer_gift = 3
test_mode_purchase = 4
free_purchase = 5
user_gift = 6
premium_purchase = 7
application_subscription = 8
class EntitlementOwnerType(Enum):
guild = 1
user = 2
class PollLayoutType(Enum):
default = 1
class InviteType(Enum):
guild = 0
group_dm = 1
friend = 2
class ReactionType(Enum):
normal = 0
burst = 1
class VoiceChannelEffectAnimationType(Enum):
premium = 0
basic = 1
class SubscriptionStatus(Enum):
active = 0
ending = 1
inactive = 2
def create_unknown_value(cls: Type[E], val: Any) -> E:

25
discord/errors.py

@ -47,6 +47,12 @@ __all__ = (
'ConnectionClosed',
'PrivilegedIntentsRequired',
'InteractionResponded',
'MissingApplicationID',
)
APP_ID_NOT_FOUND = (
'Client does not have an application_id set. Either the function was called before on_ready '
'was called or application_id was not passed to the Client constructor.'
)
@ -278,3 +284,22 @@ class InteractionResponded(ClientException):
def __init__(self, interaction: Interaction):
self.interaction: Interaction = interaction
super().__init__('This interaction has already been responded to before')
class MissingApplicationID(ClientException):
"""An exception raised when the client does not have an application ID set.
An application ID is required for syncing application commands and various
other application tasks such as SKUs or application emojis.
This inherits from :exc:`~discord.app_commands.AppCommandError`
and :class:`~discord.ClientException`.
.. versionadded:: 2.0
.. versionchanged:: 2.5
This is now exported to the ``discord`` namespace and now inherits from :class:`~discord.ClientException`.
"""
def __init__(self, message: Optional[str] = None):
super().__init__(message or APP_ID_NOT_FOUND)

40
discord/ext/commands/bot.py

@ -166,14 +166,21 @@ class BotBase(GroupMixin[None]):
help_command: Optional[HelpCommand] = _default,
tree_cls: Type[app_commands.CommandTree[Any]] = app_commands.CommandTree,
description: Optional[str] = None,
allowed_contexts: app_commands.AppCommandContext = MISSING,
allowed_installs: app_commands.AppInstallationType = MISSING,
intents: discord.Intents,
**options: Any,
) -> None:
super().__init__(intents=intents, **options)
self.command_prefix: PrefixType[BotT] = command_prefix
self.command_prefix: PrefixType[BotT] = command_prefix # type: ignore
self.extra_events: Dict[str, List[CoroFunc]] = {}
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore
if allowed_contexts is not MISSING:
self.__tree.allowed_contexts = allowed_contexts
if allowed_installs is not MISSING:
self.__tree.allowed_installs = allowed_installs
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[UserCheck] = []
@ -480,7 +487,7 @@ class BotBase(GroupMixin[None]):
if len(data) == 0:
return True
return await discord.utils.async_all(f(ctx) for f in data)
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user: User, /) -> bool:
"""|coro|
@ -499,6 +506,12 @@ class BotBase(GroupMixin[None]):
``user`` parameter is now positional-only.
.. versionchanged:: 2.4
This function now respects the team member roles if the bot is team-owned.
In order to be considered an owner, they must be either an admin or
a developer.
Parameters
-----------
user: :class:`.abc.User`
@ -515,10 +528,13 @@ class BotBase(GroupMixin[None]):
elif self.owner_ids:
return user.id in self.owner_ids
else:
app = await self.application_info() # type: ignore
app: discord.AppInfo = await self.application_info() # type: ignore
if app.team:
self.owner_ids = ids = {m.id for m in app.team.members}
self.owner_ids = ids = {
m.id
for m in app.team.members
if m.role in (discord.TeamMemberRole.admin, discord.TeamMemberRole.developer)
}
return user.id in ids
else:
self.owner_id = owner_id = app.owner.id
@ -1479,6 +1495,20 @@ class Bot(BotBase, discord.Client):
The type of application command tree to use. Defaults to :class:`~discord.app_commands.CommandTree`.
.. versionadded:: 2.0
allowed_contexts: :class:`~discord.app_commands.AppCommandContext`
The default allowed contexts that applies to all application commands
in the application command tree.
Note that you can override this on a per command basis.
.. versionadded:: 2.4
allowed_installs: :class:`~discord.app_commands.AppInstallationType`
The default allowed install locations that apply to all application commands
in the application command tree.
Note that you can override this on a per command basis.
.. versionadded:: 2.4
"""
pass

45
discord/ext/commands/cog.py

@ -25,6 +25,7 @@ from __future__ import annotations
import inspect
import discord
import logging
from discord import app_commands
from discord.utils import maybe_coroutine, _to_kebab_case
@ -50,6 +51,7 @@ from ._types import _BaseCommand, BotT
if TYPE_CHECKING:
from typing_extensions import Self
from discord.abc import Snowflake
from discord._types import ClientT
from .bot import BotBase
from .context import Context
@ -64,6 +66,7 @@ __all__ = (
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
_log = logging.getLogger(__name__)
class CogMeta(type):
@ -166,7 +169,7 @@ class CogMeta(type):
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
def __new__(cls, *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
if any(issubclass(base, app_commands.Group) for base in bases):
raise TypeError(
@ -304,6 +307,7 @@ class Cog(metaclass=CogMeta):
# Register the application commands
children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = []
app_command_refs: Dict[str, Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = {}
if cls.__cog_is_app_commands_group__:
group = app_commands.Group(
@ -314,6 +318,8 @@ class Cog(metaclass=CogMeta):
parent=None,
guild_ids=getattr(cls, '__discord_app_commands_default_guilds__', None),
guild_only=getattr(cls, '__discord_app_commands_guild_only__', False),
allowed_contexts=getattr(cls, '__discord_app_commands_contexts__', None),
allowed_installs=getattr(cls, '__discord_app_commands_installation_types__', None),
default_permissions=getattr(cls, '__discord_app_commands_default_permissions__', None),
extras=cls.__cog_group_extras__,
)
@ -330,6 +336,16 @@ class Cog(metaclass=CogMeta):
# Get the latest parent reference
parent = lookup[parent.qualified_name] # type: ignore
# Hybrid commands already deal with updating the reference
# Due to the copy below, so we need to handle them specially
if hasattr(parent, '__commands_is_hybrid__') and hasattr(command, '__commands_is_hybrid__'):
current: Optional[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = getattr(
command, 'app_command', None
)
updated = app_command_refs.get(command.qualified_name)
if current and updated:
command.app_command = updated # type: ignore # Safe attribute access
# Update our parent's reference to our self
parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore
@ -344,8 +360,15 @@ class Cog(metaclass=CogMeta):
# The type checker does not see the app_command attribute even though it exists
command.app_command = app_command # type: ignore
# Update all the references to point to the new copy
if isinstance(app_command, app_commands.Group):
for child in app_command.walk_commands():
app_command_refs[child.qualified_name] = child
if hasattr(child, '__commands_is_hybrid_app_command__') and child.qualified_name in lookup:
child.wrapped = lookup[child.qualified_name] # type: ignore
if self.__cog_app_commands_group__:
children.append(app_command) # type: ignore # Somehow it thinks it can be None here
children.append(app_command)
if Cog._get_overridden_method(self.cog_app_command_error) is not None:
error_handler = self.cog_app_command_error
@ -376,7 +399,7 @@ class Cog(metaclass=CogMeta):
if len(mapping) > 25:
raise TypeError('maximum number of application command children exceeded')
self.__cog_app_commands_group__._children = mapping # type: ignore # Variance issue
self.__cog_app_commands_group__._children = mapping
return self
@ -549,6 +572,8 @@ class Cog(metaclass=CogMeta):
Subclasses must replace this if they want special unloading behaviour.
Exceptions raised in this method are ignored during extension unloading.
.. versionchanged:: 2.0
This method can now be a :term:`coroutine`.
@ -585,6 +610,18 @@ class Cog(metaclass=CogMeta):
"""
return True
@_cog_special_method
def interaction_check(self, interaction: discord.Interaction[ClientT], /) -> bool:
"""A special method that registers as a :func:`discord.app_commands.check`
for every app command and subcommand in this cog.
This function **can** be a coroutine and must take a sole parameter,
``interaction``, to represent the :class:`~discord.Interaction`.
.. versionadded:: 2.0
"""
return True
@_cog_special_method
async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None:
"""|coro|
@ -738,7 +775,7 @@ class Cog(metaclass=CogMeta):
try:
await maybe_coroutine(self.cog_unload)
except Exception:
pass
_log.exception('Ignoring exception in cog unload for Cog %r (%r)', cls, self.qualified_name)
class GroupCog(Cog):

222
discord/ext/commands/context.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, List, Optional, TypeVar, Union, Sequence, Type
from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, List, Optional, TypeVar, Union, Sequence, Type, overload
import discord.abc
import discord.utils
@ -50,6 +50,7 @@ if TYPE_CHECKING:
from discord.message import MessageReference, PartialMessage
from discord.ui import View
from discord.types.interactions import ApplicationCommandInteractionData
from discord.poll import Poll
from .cog import Cog
from .core import Command
@ -81,16 +82,20 @@ def is_cog(obj: Any) -> TypeGuard[Cog]:
return hasattr(obj, '__cog_commands__')
class DeferTyping:
class DeferTyping(Generic[BotT]):
def __init__(self, ctx: Context[BotT], *, ephemeral: bool):
self.ctx: Context[BotT] = ctx
self.ephemeral: bool = ephemeral
async def do_defer(self) -> None:
if self.ctx.interaction and not self.ctx.interaction.response.is_done():
await self.ctx.interaction.response.defer(ephemeral=self.ephemeral)
def __await__(self) -> Generator[Any, None, None]:
return self.ctx.defer(ephemeral=self.ephemeral).__await__()
return self.do_defer().__await__()
async def __aenter__(self) -> None:
await self.ctx.defer(ephemeral=self.ephemeral)
await self.do_defer()
async def __aexit__(
self,
@ -251,7 +256,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if command is None:
raise ValueError('interaction does not have command data')
bot: BotT = interaction.client # type: ignore
bot: BotT = interaction.client
data: ApplicationCommandInteractionData = interaction.data # type: ignore
if interaction.message is None:
synthetic_payload = {
@ -430,6 +435,14 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return None
return self.command.cog
@property
def filesize_limit(self) -> int:
""":class:`int`: Returns the maximum number of bytes files can have when uploaded to this guild or DM channel associated with this context.
.. versionadded:: 2.3
"""
return self.guild.filesize_limit if self.guild is not None else discord.utils.DEFAULT_FILE_SIZE_LIMIT_BYTES
@discord.utils.cached_property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
@ -464,7 +477,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.0
"""
if self.channel.type is ChannelType.private:
if self.interaction is None and self.channel.type is ChannelType.private:
return Permissions._dm_permissions()
if not self.interaction:
# channel and author will always match relevant types here
@ -498,7 +511,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.0
"""
channel = self.channel
if channel.type == ChannelType.private:
if self.interaction is None and channel.type == ChannelType.private:
return Permissions._dm_permissions()
if not self.interaction:
# channel and me will always match relevant types here
@ -615,6 +628,94 @@ class Context(discord.abc.Messageable, Generic[BotT]):
except CommandError as e:
await cmd.on_help_command_error(self, e)
@overload
async def reply(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@overload
async def reply(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@overload
async def reply(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@overload
async def reply(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
"""|coro|
@ -650,7 +751,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
else:
return await self.send(content, **kwargs)
def typing(self, *, ephemeral: bool = False) -> Union[Typing, DeferTyping]:
def typing(self, *, ephemeral: bool = False) -> Union[Typing, DeferTyping[BotT]]:
"""Returns an asynchronous context manager that allows you to send a typing indicator to
the destination for an indefinite period of time, or 10 seconds if the context manager
is called using ``await``.
@ -716,6 +817,94 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if self.interaction:
await self.interaction.response.defer(ephemeral=ephemeral)
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...
async def send(
self,
content: Optional[str] = None,
@ -735,6 +924,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
suppress_embeds: bool = False,
ephemeral: bool = False,
silent: bool = False,
poll: Optional[Poll] = None,
) -> Message:
"""|coro|
@ -824,6 +1014,13 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.2
poll: Optional[:class:`~discord.Poll`]
The poll to send with this message.
.. versionadded:: 2.4
.. versionchanged:: 2.6
This can now be ``None`` and defaults to ``None`` instead of ``MISSING``.
Raises
--------
~discord.HTTPException
@ -861,6 +1058,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
view=view,
suppress_embeds=suppress_embeds,
silent=silent,
poll=poll,
) # type: ignore # The overloads don't support Optional but the implementation does
# Convert the kwargs from None to MISSING to appease the remaining implementations
@ -876,13 +1074,17 @@ class Context(discord.abc.Messageable, Generic[BotT]):
'suppress_embeds': suppress_embeds,
'ephemeral': ephemeral,
'silent': silent,
'poll': MISSING if poll is None else poll,
}
if self.interaction.response.is_done():
msg = await self.interaction.followup.send(**kwargs, wait=True)
else:
await self.interaction.response.send_message(**kwargs)
msg = await self.interaction.original_response()
response = await self.interaction.response.send_message(**kwargs)
if not isinstance(response.resource, discord.InteractionMessage):
msg = await self.interaction.original_response()
else:
msg = response.resource
if delete_after is not None:
await msg.delete(delay=delete_after)

182
discord/ext/commands/converter.py

@ -82,6 +82,7 @@ __all__ = (
'GuildChannelConverter',
'GuildStickerConverter',
'ScheduledEventConverter',
'SoundboardSoundConverter',
'clean_content',
'Greedy',
'Range',
@ -126,6 +127,10 @@ class Converter(Protocol[T_co]):
raise a :exc:`.CommandError` derived exception as it will
properly propagate to the error handlers.
Note that if this method is called manually, :exc:`Exception`
should be caught to handle the cases where a subclass does
not explicitly inherit from :exc:`.CommandError`.
Parameters
-----------
ctx: :class:`.Context`
@ -186,9 +191,11 @@ class MemberConverter(IDConverter[discord.Member]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name#discrim
4. Lookup by name
5. Lookup by nickname
3. Lookup by username#discriminator (deprecated).
4. Lookup by username#0 (deprecated, only gets users that migrated from their discriminator).
5. Lookup by user name.
6. Lookup by global name.
7. Lookup by guild nickname.
.. versionchanged:: 1.5
Raise :exc:`.MemberNotFound` instead of generic :exc:`.BadArgument`
@ -196,17 +203,29 @@ class MemberConverter(IDConverter[discord.Member]):
.. versionchanged:: 1.5.1
This converter now lazily fetches members from the gateway and HTTP APIs,
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
.. deprecated:: 2.3
Looking up users by discriminator will be removed in a future version due to
the removal of discriminators in an API change.
"""
async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
username, _, discriminator = argument.rpartition('#')
# If # isn't found then "discriminator" actually has the username
if not username:
discriminator, username = username, discriminator
if discriminator == '0' or (len(discriminator) == 4 and discriminator.isdigit()):
lookup = username
predicate = lambda m: m.name == username and m.discriminator == discriminator
else:
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
lookup = argument
predicate = lambda m: m.name == argument or m.global_name == argument or m.nick == argument
members = await guild.query_members(lookup, limit=100, cache=cache)
return discord.utils.find(predicate, members)
async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]:
ws = bot._get_websocket(shard_id=guild.shard_id)
@ -273,8 +292,10 @@ class UserConverter(IDConverter[discord.User]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name#discrim
4. Lookup by name
3. Lookup by username#discriminator (deprecated).
4. Lookup by username#0 (deprecated, only gets users that migrated from their discriminator).
5. Lookup by user name.
6. Lookup by global name.
.. versionchanged:: 1.5
Raise :exc:`.UserNotFound` instead of generic :exc:`.BadArgument`
@ -282,6 +303,10 @@ class UserConverter(IDConverter[discord.User]):
.. versionchanged:: 1.6
This converter now lazily fetches users from the HTTP APIs if an ID is passed
and it's not available in cache.
.. deprecated:: 2.3
Looking up users by discriminator will be removed in a future version due to
the removal of discriminators in an API change.
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
@ -300,25 +325,18 @@ class UserConverter(IDConverter[discord.User]):
return result # type: ignore
arg = argument
username, _, discriminator = argument.rpartition('#')
# Remove the '@' character if this is the first character from the argument
if arg[0] == '@':
# Remove first character
arg = arg[1:]
# If # isn't found then "discriminator" actually has the username
if not username:
discriminator, username = username, discriminator
# check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == '#':
discrim = arg[-4:]
name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim
result = discord.utils.find(predicate, state._users.values())
if result is not None:
return result
if discriminator == '0' or (len(discriminator) == 4 and discriminator.isdigit()):
predicate = lambda u: u.name == username and u.discriminator == discriminator
else:
predicate = lambda u: u.name == argument or u.global_name == argument
predicate = lambda u: u.name == arg
result = discord.utils.find(predicate, state._users.values())
if result is None:
raise UserNotFound(argument)
@ -425,19 +443,36 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name.
3. Lookup by channel URL.
4. Lookup by name.
.. versionadded:: 2.0
.. versionchanged:: 2.4
Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels.
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
@staticmethod
def _parse_from_url(argument: str) -> Optional[re.Match[str]]:
link_regex = re.compile(
r'https?://(?:(?:ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?:[0-9]{15,20}|@me)'
r'/([0-9]{15,20})(?:/(?:[0-9]{15,20})/?)?$'
)
return link_regex.match(argument)
@staticmethod
def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = (
IDConverter._get_id_match(argument)
or re.match(r'<#([0-9]{15,20})>$', argument)
or GuildChannelConverter._parse_from_url(argument)
)
result = None
guild = ctx.guild
@ -467,7 +502,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
@staticmethod
def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = (
IDConverter._get_id_match(argument)
or re.match(r'<#([0-9]{15,20})>$', argument)
or GuildChannelConverter._parse_from_url(argument)
)
result = None
guild = ctx.guild
@ -497,10 +536,14 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
3. Lookup by channel URL.
4. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
.. versionchanged:: 2.4
Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels.
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
@ -517,10 +560,14 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
3. Lookup by channel URL.
4. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
.. versionchanged:: 2.4
Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels.
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
@ -539,7 +586,11 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
3. Lookup by channel URL.
4. Lookup by name
.. versionchanged:: 2.4
Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels.
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
@ -556,7 +607,11 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
3. Lookup by channel URL.
4. Lookup by name
.. versionchanged:: 2.4
Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels.
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
@ -575,9 +630,13 @@ class ThreadConverter(IDConverter[discord.Thread]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name.
3. Lookup by channel URL.
4. Lookup by name.
.. versionadded: 2.0
.. versionchanged:: 2.4
Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels.
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
@ -594,9 +653,13 @@ class ForumChannelConverter(IDConverter[discord.ForumChannel]):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
3. Lookup by channel URL.
4. Lookup by name
.. versionadded:: 2.0
.. versionchanged:: 2.4
Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels.
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ForumChannel:
@ -889,6 +952,44 @@ class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
return result
class SoundboardSoundConverter(IDConverter[discord.SoundboardSound]):
"""Converts to a :class:`~discord.SoundboardSound`.
Lookups are done for the local guild if available. Otherwise, for a DM context,
lookup is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by name.
.. versionadded:: 2.5
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.SoundboardSound:
guild = ctx.guild
match = self._get_id_match(argument)
result = None
if match:
# ID match
sound_id = int(match.group(1))
if guild:
result = guild.get_soundboard_sound(sound_id)
else:
result = ctx.bot.get_soundboard_sound(sound_id)
else:
# lookup by name
if guild:
result = discord.utils.get(guild.soundboard_sounds, name=argument)
else:
result = discord.utils.get(ctx.bot.soundboard_sounds, name=argument)
if result is None:
raise SoundboardSoundNotFound(argument)
return result
class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of
said content.
@ -1024,7 +1125,7 @@ class Greedy(List[T]):
args = getattr(converter, '__args__', ())
if discord.utils.PY_310 and converter.__class__ is types.UnionType: # type: ignore
converter = Union[args] # type: ignore
converter = Union[args]
origin = getattr(converter, '__origin__', None)
@ -1037,7 +1138,7 @@ class Greedy(List[T]):
if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.')
return cls(converter=converter)
return cls(converter=converter) # type: ignore
@property
def constructed_converter(self) -> Any:
@ -1172,7 +1273,7 @@ def _convert_to_bool(argument: str) -> bool:
raise BadBoolArgument(lowered)
_GenericAlias = type(List[T])
_GenericAlias = type(List[T]) # type: ignore
def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
@ -1201,6 +1302,7 @@ CONVERTER_MAPPING: Dict[type, Any] = {
discord.GuildSticker: GuildStickerConverter,
discord.ScheduledEvent: ScheduledEventConverter,
discord.ForumChannel: ForumChannelConverter,
discord.SoundboardSound: SoundboardSoundConverter,
}
@ -1223,7 +1325,7 @@ async def _actual_conversion(ctx: Context[BotT], converter: Any, argument: str,
else:
return await converter().convert(ctx, argument)
elif isinstance(converter, Converter):
return await converter.convert(ctx, argument) # type: ignore
return await converter.convert(ctx, argument)
except CommandError:
raise
except Exception as exc:
@ -1330,7 +1432,7 @@ async def run_converters(ctx: Context[BotT], converter: Any, argument: str, para
return value
# if we're here, then we failed to match all the literals
raise BadLiteralArgument(param, literal_args, errors)
raise BadLiteralArgument(param, literal_args, errors, argument)
# This must be the last if-clause in the chain of origin checking
# Nearly every type is a generic type within the typing library

4
discord/ext/commands/cooldowns.py

@ -27,11 +27,11 @@ from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, Union, Generic, TypeVar, TYPE_CHECKING
from discord.enums import Enum
from discord.abc import PrivateChannel
import time
import asyncio
from collections import deque
from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached
from .context import Context
from discord.app_commands import Cooldown as Cooldown
@ -71,7 +71,7 @@ class BucketType(Enum):
elif self is BucketType.member:
return ((msg.guild and msg.guild.id), msg.author.id)
elif self is BucketType.category:
return (msg.channel.category or msg.channel).id # type: ignore
return (getattr(msg.channel, 'category', None) or msg.channel).id
elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role

39
discord/ext/commands/core.py

@ -151,6 +151,7 @@ def get_signature_parameters(
parameter._default = default.default
parameter._description = default._description
parameter._displayed_default = default._displayed_default
parameter._displayed_name = default._displayed_name
annotation = parameter.annotation
@ -194,8 +195,13 @@ def extract_descriptions_from_docstring(function: Callable[..., Any], params: Di
description, param_docstring = divide
for match in NUMPY_DOCSTRING_ARG_REGEX.finditer(param_docstring):
name = match.group('name')
if name not in params:
continue
is_display_name = discord.utils.get(params.values(), displayed_name=name)
if is_display_name:
name = is_display_name.name
else:
continue
param = params[name]
if param.description is None:
@ -455,7 +461,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# bandaid for the fact that sometimes parent can be the bot instance
parent: Optional[GroupMixin[Any]] = kwargs.get('parent')
self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None # type: ignore # Does not recognise mixin usage
self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None
self._before_invoke: Optional[Hook] = None
try:
@ -1169,7 +1175,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ''
result = []
for name, param in params.items():
for param in params.values():
name = param.displayed_name or param.name
greedy = isinstance(param.converter, Greedy)
optional = False # postpone evaluation of if it's an optional argument
@ -1277,7 +1285,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# since we have no checks, then we just return True.
return True
return await discord.utils.async_all(predicate(ctx) for predicate in predicates)
return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore
finally:
ctx.command = original
@ -1996,7 +2004,7 @@ def check_any(*checks: Check[ContextT]) -> Check[ContextT]:
# if we're here, all checks failed
raise CheckAnyFailure(unwrapped, errors)
return check(predicate) # type: ignore
return check(predicate)
def has_role(item: Union[int, str], /) -> Check[Any]:
@ -2036,7 +2044,7 @@ def has_role(item: Union[int, str], /) -> Check[Any]:
# ctx.guild is None doesn't narrow ctx.author to Member
if isinstance(item, int):
role = discord.utils.get(ctx.author.roles, id=item) # type: ignore
role = ctx.author.get_role(item) # type: ignore
else:
role = discord.utils.get(ctx.author.roles, name=item) # type: ignore
if role is None:
@ -2083,8 +2091,12 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
if any(
ctx.author.get_role(item) is not None
if isinstance(item, int)
else discord.utils.get(ctx.author.roles, name=item) is not None
for item in items
):
return True
raise MissingAnyRole(list(items))
@ -2113,11 +2125,10 @@ def bot_has_role(item: int, /) -> Callable[[T], T]:
if ctx.guild is None:
raise NoPrivateMessage()
me = ctx.me
if isinstance(item, int):
role = discord.utils.get(me.roles, id=item)
role = ctx.me.get_role(item)
else:
role = discord.utils.get(me.roles, name=item)
role = discord.utils.get(ctx.me.roles, name=item)
if role is None:
raise BotMissingRole(item)
return True
@ -2144,8 +2155,10 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]:
raise NoPrivateMessage()
me = ctx.me
getter = functools.partial(discord.utils.get, me.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
if any(
me.get_role(item) is not None if isinstance(item, int) else discord.utils.get(me.roles, name=item) is not None
for item in items
):
return True
raise BotMissingAnyRole(list(items))

82
discord/ext/commands/errors.py

@ -24,9 +24,12 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, Generic
from discord.errors import ClientException, DiscordException
from discord.utils import _human_join
from ._types import BotT
if TYPE_CHECKING:
from discord.abc import GuildChannel
@ -34,7 +37,6 @@ if TYPE_CHECKING:
from discord.types.snowflake import Snowflake, SnowflakeList
from discord.app_commands import AppCommandError
from ._types import BotT
from .context import Context
from .converter import Converter
from .cooldowns import BucketType, Cooldown
@ -74,6 +76,7 @@ __all__ = (
'EmojiNotFound',
'GuildStickerNotFound',
'ScheduledEventNotFound',
'SoundboardSoundNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
@ -182,7 +185,7 @@ class MissingRequiredArgument(UserInputError):
def __init__(self, param: Parameter) -> None:
self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.')
super().__init__(f'{param.displayed_name or param.name} is a required argument that is missing.')
class MissingRequiredAttachment(UserInputError):
@ -201,7 +204,7 @@ class MissingRequiredAttachment(UserInputError):
def __init__(self, param: Parameter) -> None:
self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing an attachment.')
super().__init__(f'{param.displayed_name or param.name} is a required argument that is missing an attachment.')
class TooManyArguments(UserInputError):
@ -233,7 +236,7 @@ class CheckFailure(CommandError):
pass
class CheckAnyFailure(CheckFailure):
class CheckAnyFailure(Generic[BotT], CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail.
This inherits from :exc:`CheckFailure`.
@ -563,6 +566,24 @@ class ScheduledEventNotFound(BadArgument):
super().__init__(f'ScheduledEvent "{argument}" not found.')
class SoundboardSoundNotFound(BadArgument):
"""Exception raised when the bot can not find the soundboard sound.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.5
Attributes
-----------
argument: :class:`str`
The sound supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'SoundboardSound "{argument}" not found.')
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
@ -758,12 +779,7 @@ class MissingAnyRole(CheckFailure):
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = _human_join(missing)
message = f'You are missing at least one of the required roles: {fmt}'
super().__init__(message)
@ -788,12 +804,7 @@ class BotMissingAnyRole(CheckFailure):
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(', '.join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = _human_join(missing)
message = f'Bot is missing at least one of the required roles: {fmt}'
super().__init__(message)
@ -832,11 +843,7 @@ class MissingPermissions(CheckFailure):
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(', '.join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
fmt = _human_join(missing, final='and')
message = f'You are missing {fmt} permission(s) to run this command.'
super().__init__(message, *args)
@ -857,11 +864,7 @@ class BotMissingPermissions(CheckFailure):
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(', '.join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
fmt = _human_join(missing, final='and')
message = f'Bot requires {fmt} permission(s) to run this command.'
super().__init__(message, *args)
@ -896,12 +899,8 @@ class BadUnionArgument(UserInputError):
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
fmt = _human_join(to_string)
super().__init__(f'Could not convert "{param.displayed_name or param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError):
@ -920,20 +919,21 @@ class BadLiteralArgument(UserInputError):
A tuple of values compared against in conversion, in order of failure.
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
argument: :class:`str`
The argument's value that failed to be converted. Defaults to an empty string.
.. versionadded:: 2.3
"""
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError], argument: str = "") -> None:
self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals
self.errors: List[CommandError] = errors
self.argument: str = argument
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
fmt = _human_join(to_string)
super().__init__(f'Could not convert "{param.displayed_name or param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError):
@ -1081,7 +1081,7 @@ class ExtensionNotFound(ExtensionError):
"""
def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.'
msg = f'Extension {name!r} could not be loaded or found.'
super().__init__(msg, name=name)

45
discord/ext/commands/flags.py

@ -79,6 +79,10 @@ class Flag:
description: :class:`str`
The description of the flag. Shown for hybrid commands when they're
used as application commands.
positional: :class:`bool`
Whether the flag is positional or not. There can only be one positional flag.
.. versionadded:: 2.4
"""
name: str = MISSING
@ -89,6 +93,7 @@ class Flag:
max_args: int = MISSING
override: bool = MISSING
description: str = MISSING
positional: bool = MISSING
cast_to_dict: bool = False
@property
@ -109,6 +114,7 @@ def flag(
override: bool = MISSING,
converter: Any = MISSING,
description: str = MISSING,
positional: bool = MISSING,
) -> Any:
"""Override default functionality and parameters of the underlying :class:`FlagConverter`
class attributes.
@ -136,6 +142,10 @@ def flag(
description: :class:`str`
The description of the flag. Shown for hybrid commands when they're
used as application commands.
positional: :class:`bool`
Whether the flag is positional or not. There can only be one positional flag.
.. versionadded:: 2.4
"""
return Flag(
name=name,
@ -145,6 +155,7 @@ def flag(
override=override,
annotation=converter,
description=description,
positional=positional,
)
@ -171,6 +182,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
positional: Optional[Flag] = None
for name, annotation in annotations.items():
flag = namespace.pop(name, MISSING)
if isinstance(flag, Flag):
@ -183,6 +195,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.name is MISSING:
flag.name = name
if flag.positional:
if positional is not None:
raise TypeError(f"{flag.name!r} positional flag conflicts with {positional.name!r} flag.")
positional = flag
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
@ -270,6 +287,7 @@ class FlagsMeta(type):
__commands_flag_case_insensitive__: bool
__commands_flag_delimiter__: str
__commands_flag_prefix__: str
__commands_flag_positional__: Optional[Flag]
def __new__(
cls,
@ -280,7 +298,7 @@ class FlagsMeta(type):
case_insensitive: bool = MISSING,
delimiter: str = MISSING,
prefix: str = MISSING,
) -> Self:
) -> FlagsMeta:
attrs['__commands_is_flag__'] = True
try:
@ -324,9 +342,13 @@ class FlagsMeta(type):
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')
positional: Optional[Flag] = None
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
aliases.update({alias_name: flag_name for alias_name in flag.aliases})
if flag.positional:
positional = flag
attrs['__commands_flag_positional__'] = positional
forbidden = set(delimiter).union(prefix)
for flag_name in flags:
@ -421,7 +443,7 @@ async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation
return await convert_flag(ctx, argument, flag, annotation)
elif origin is Union and type(None) in annotation.__args__:
# typing.Optional[x]
annotation = Union[tuple(arg for arg in annotation.__args__ if arg is not type(None))] # type: ignore
annotation = Union[tuple(arg for arg in annotation.__args__ if arg is not type(None))]
return await run_converters(ctx, annotation, argument, param)
elif origin is dict:
# typing.Dict[K, V] -> typing.Tuple[K, V]
@ -485,7 +507,7 @@ class FlagConverter(metaclass=FlagsMeta):
for flag in flags.values():
if callable(flag.default):
# Type checker does not understand that flag.default is a Callable
default = await maybe_coroutine(flag.default, ctx) # type: ignore
default = await maybe_coroutine(flag.default, ctx)
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)
@ -500,10 +522,25 @@ class FlagConverter(metaclass=FlagsMeta):
result: Dict[str, List[str]] = {}
flags = cls.__commands_flags__
aliases = cls.__commands_flag_aliases__
positional_flag = cls.__commands_flag_positional__
last_position = 0
last_flag: Optional[Flag] = None
case_insensitive = cls.__commands_flag_case_insensitive__
if positional_flag is not None:
match = cls.__commands_flag_regex__.search(argument)
if match is not None:
begin, end = match.span(0)
value = argument[:begin].strip()
else:
value = argument.strip()
last_position = len(argument)
if value:
name = positional_flag.name.casefold() if case_insensitive else positional_flag.name
result[name] = [value]
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group('flag')
@ -600,7 +637,7 @@ class FlagConverter(metaclass=FlagsMeta):
else:
if callable(flag.default):
# Type checker does not understand flag.default is a Callable
default = await maybe_coroutine(flag.default, ctx) # type: ignore
default = await maybe_coroutine(flag.default, ctx)
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)

15
discord/ext/commands/help.py

@ -294,6 +294,14 @@ class _HelpCommandImpl(Command):
cog.walk_commands = cog.walk_commands.__wrapped__
self.cog = None
# Revert `on_error` to use the original one in case of race conditions
self.on_error = self._injected.on_help_command_error
def update(self, **kwargs: Any) -> None:
cog = self.cog
self.__init__(self._original, **dict(self.__original_kwargs__, **kwargs))
self.cog = cog
class HelpCommand:
r"""The base implementation for help command formatting.
@ -374,9 +382,8 @@ class HelpCommand:
return obj
def _add_to_bot(self, bot: BotBase) -> None:
command = _HelpCommandImpl(self, **self.command_attrs)
bot.add_command(command)
self._command_impl = command
self._command_impl.update(**self.command_attrs)
bot.add_command(self._command_impl)
def _remove_from_bot(self, bot: BotBase) -> None:
bot.remove_command(self._command_impl.name)
@ -1166,7 +1173,7 @@ class DefaultHelpCommand(HelpCommand):
get_width = discord.utils._string_width
for argument in arguments:
name = argument.name
name = argument.displayed_name or argument.name
width = max_size - (get_width(name) - len(name))
entry = f'{self.indent * " "}{name:<{width}} {argument.description or self.default_argument_description}'
# we do not want to shorten the default value, if any.

67
discord/ext/commands/hybrid.py

@ -43,7 +43,7 @@ import inspect
from discord import app_commands
from discord.utils import MISSING, maybe_coroutine, async_all
from .core import Command, Group
from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError
from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError, DisabledCommand
from .converter import Converter, Range, Greedy, run_converters, CONVERTER_MAPPING
from .parameters import Parameter
from .flags import is_flag, FlagConverter
@ -72,9 +72,9 @@ __all__ = (
T = TypeVar('T')
U = TypeVar('U')
CogT = TypeVar('CogT', bound='Cog')
CommandT = TypeVar('CommandT', bound='Command')
CommandT = TypeVar('CommandT', bound='Command[Any, ..., Any]')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
GroupT = TypeVar('GroupT', bound='Group[Any, ..., Any]')
_NoneType = type(None)
if TYPE_CHECKING:
@ -203,9 +203,9 @@ def replace_parameter(
# Fallback to see if the behaviour needs changing
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', [])
if isinstance(converter, Range):
if isinstance(converter, Range): # type: ignore # Range is not an Annotation at runtime
r = converter
param = param.replace(annotation=app_commands.Range[r.annotation, r.min, r.max])
param = param.replace(annotation=app_commands.Range[r.annotation, r.min, r.max]) # type: ignore
elif isinstance(converter, Greedy):
# Greedy is "optional" in ext.commands
# However, in here, it probably makes sense to make it required.
@ -234,6 +234,12 @@ def replace_parameter(
descriptions[name] = flag.description
if flag.name != flag.attribute:
renames[name] = flag.name
if pseudo.default is not pseudo.empty:
# This ensures the default is wrapped around _CallableDefault if callable
# else leaves it as-is.
pseudo = pseudo.replace(
default=_CallableDefault(flag.default) if callable(flag.default) else flag.default
)
mapping[name] = pseudo
@ -251,7 +257,7 @@ def replace_parameter(
inner = args[0]
is_inner_transformer = is_transformer(inner)
if is_converter(inner) and not is_inner_transformer:
param = param.replace(annotation=Optional[ConverterTransformer(inner, original)]) # type: ignore
param = param.replace(annotation=Optional[ConverterTransformer(inner, original)])
else:
raise
elif origin:
@ -283,7 +289,7 @@ def replace_parameters(
param = param.replace(default=default)
if isinstance(param.default, Parameter):
# If we're here, then then it hasn't been handled yet so it should be removed completely
# If we're here, then it hasn't been handled yet so it should be removed completely
param = param.replace(default=parameter.empty)
# Flags are flattened out and thus don't get their parameter in the actual mapping
@ -297,14 +303,20 @@ def replace_parameters(
class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
def __init__(self, wrapped: Union[HybridCommand[CogT, Any, T], HybridGroup[CogT, Any, T]]) -> None:
__commands_is_hybrid_app_command__: ClassVar[bool] = True
def __init__(
self,
wrapped: Union[HybridCommand[CogT, ..., T], HybridGroup[CogT, ..., T]],
name: Optional[Union[str, app_commands.locale_str]] = None,
) -> None:
signature = inspect.signature(wrapped.callback)
params = replace_parameters(wrapped.params, wrapped.callback, signature)
wrapped.callback.__signature__ = signature.replace(parameters=params)
nsfw = getattr(wrapped.callback, '__discord_app_commands_is_nsfw__', False)
try:
super().__init__(
name=wrapped._locale_name or wrapped.name,
name=name or wrapped._locale_name or wrapped.name,
callback=wrapped.callback, # type: ignore # Signature doesn't match but we're overriding the invoke
description=wrapped._locale_description or wrapped.description or wrapped.short_doc or '',
nsfw=nsfw,
@ -312,7 +324,7 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
finally:
del wrapped.callback.__signature__
self.wrapped: Union[HybridCommand[CogT, Any, T], HybridGroup[CogT, Any, T]] = wrapped
self.wrapped: Union[HybridCommand[CogT, ..., T], HybridGroup[CogT, ..., T]] = wrapped
self.binding: Optional[CogT] = wrapped.cog
# This technically means only one flag converter is supported
self.flag_converter: Optional[Tuple[str, Type[FlagConverter]]] = getattr(
@ -398,7 +410,7 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
if self.binding is not None:
try:
# Type checker does not like runtime attribute retrieval
check: AppCommandCheck = self.binding.interaction_check # type: ignore
check: AppCommandCheck = self.binding.interaction_check
except AttributeError:
pass
else:
@ -412,10 +424,10 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
if not ret:
return False
if self.checks and not await async_all(f(interaction) for f in self.checks):
if self.checks and not await async_all(f(interaction) for f in self.checks): # type: ignore
return False
if self.wrapped.checks and not await async_all(f(ctx) for f in self.wrapped.checks):
if self.wrapped.checks and not await async_all(f(ctx) for f in self.wrapped.checks): # type: ignore
return False
return True
@ -520,6 +532,9 @@ class HybridCommand(Command[CogT, P, T]):
self.app_command.binding = value
async def can_run(self, ctx: Context[BotT], /) -> bool:
if not self.enabled:
raise DisabledCommand(f'{self.name} command is disabled')
if ctx.interaction is not None and self.app_command:
return await self.app_command._check_can_run(ctx.interaction)
else:
@ -594,6 +609,8 @@ class HybridGroup(Group[CogT, P, T]):
application command groups cannot be invoked, this creates a subcommand within
the group that can be invoked with the given group callback. If ``None``
then no fallback command is given. Defaults to ``None``.
fallback_locale: Optional[:class:`~discord.app_commands.locale_str`]
The fallback command name's locale string, if available.
"""
__commands_is_hybrid__: ClassVar[bool] = True
@ -603,7 +620,7 @@ class HybridGroup(Group[CogT, P, T]):
*args: Any,
name: Union[str, app_commands.locale_str] = MISSING,
description: Union[str, app_commands.locale_str] = MISSING,
fallback: Optional[str] = None,
fallback: Optional[Union[str, app_commands.locale_str]] = None,
**attrs: Any,
) -> None:
name, name_locale = (name.message, name) if isinstance(name, app_commands.locale_str) else (name, None)
@ -631,7 +648,12 @@ class HybridGroup(Group[CogT, P, T]):
# However, Python does not have conditional typing so it's very hard to
# make this type depend on the with_app_command bool without a lot of needless repetition
self.app_command: app_commands.Group = MISSING
fallback, fallback_locale = (
(fallback.message, fallback) if isinstance(fallback, app_commands.locale_str) else (fallback, None)
)
self.fallback: Optional[str] = fallback
self.fallback_locale: Optional[app_commands.locale_str] = fallback_locale
if self.with_app_command:
guild_ids = attrs.pop('guild_ids', None) or getattr(
@ -640,6 +662,8 @@ class HybridGroup(Group[CogT, P, T]):
guild_only = getattr(self.callback, '__discord_app_commands_guild_only__', False)
default_permissions = getattr(self.callback, '__discord_app_commands_default_permissions__', None)
nsfw = getattr(self.callback, '__discord_app_commands_is_nsfw__', False)
contexts = getattr(self.callback, '__discord_app_commands_contexts__', MISSING)
installs = getattr(self.callback, '__discord_app_commands_installation_types__', MISSING)
self.app_command = app_commands.Group(
name=self._locale_name or self.name,
description=self._locale_description or self.description or self.short_doc or '',
@ -647,6 +671,8 @@ class HybridGroup(Group[CogT, P, T]):
guild_only=guild_only,
default_permissions=default_permissions,
nsfw=nsfw,
allowed_installs=installs,
allowed_contexts=contexts,
)
# This prevents the group from re-adding the command at __init__
@ -654,8 +680,7 @@ class HybridGroup(Group[CogT, P, T]):
self.app_command.module = self.module
if fallback is not None:
command = HybridAppCommand(self)
command.name = fallback
command = HybridAppCommand(self, name=fallback_locale or fallback)
self.app_command.add_command(command)
@property
@ -890,7 +915,8 @@ def hybrid_command(
def decorator(func: CommandCallback[CogT, ContextT, P, T]) -> HybridCommand[CogT, P, T]:
if isinstance(func, Command):
raise TypeError('Callback is already a command.')
return HybridCommand(func, name=name, with_app_command=with_app_command, **attrs) # type: ignore # ???
# Pyright does not allow Command[Any] to be assigned to Command[CogT] despite it being okay here
return HybridCommand(func, name=name, with_app_command=with_app_command, **attrs) # type: ignore
return decorator
@ -908,6 +934,9 @@ def hybrid_group(
Parameters
-----------
name: Union[:class:`str`, :class:`~discord.app_commands.locale_str`]
The name to create the group with. By default this uses the
function name unchanged.
with_app_command: :class:`bool`
Whether to register the command also as an application command.
@ -917,9 +946,9 @@ def hybrid_group(
If the function is not a coroutine or is already a command.
"""
def decorator(func: CommandCallback[CogT, ContextT, P, T]):
def decorator(func: CommandCallback[CogT, ContextT, P, T]) -> HybridGroup[CogT, P, T]:
if isinstance(func, Command):
raise TypeError('Callback is already a command.')
return HybridGroup(func, name=name, with_app_command=with_app_command, **attrs)
return decorator # type: ignore
return decorator

51
discord/ext/commands/parameters.py

@ -87,7 +87,7 @@ class Parameter(inspect.Parameter):
.. versionadded:: 2.0
"""
__slots__ = ('_displayed_default', '_description', '_fallback')
__slots__ = ('_displayed_default', '_description', '_fallback', '_displayed_name')
def __init__(
self,
@ -97,6 +97,7 @@ class Parameter(inspect.Parameter):
annotation: Any = empty,
description: str = empty,
displayed_default: str = empty,
displayed_name: str = empty,
) -> None:
super().__init__(name=name, kind=kind, default=default, annotation=annotation)
self._name = name
@ -106,6 +107,10 @@ class Parameter(inspect.Parameter):
self._annotation = annotation
self._displayed_default = displayed_default
self._fallback = False
self._displayed_name = displayed_name
def __repr__(self) -> str:
return f'<{self.__class__.__name__} name={self._name!r} required={self.required}>'
def replace(
self,
@ -116,6 +121,7 @@ class Parameter(inspect.Parameter):
annotation: Any = MISSING,
description: str = MISSING,
displayed_default: Any = MISSING,
displayed_name: Any = MISSING,
) -> Self:
if name is MISSING:
name = self._name
@ -129,15 +135,20 @@ class Parameter(inspect.Parameter):
description = self._description
if displayed_default is MISSING:
displayed_default = self._displayed_default
if displayed_name is MISSING:
displayed_name = self._displayed_name
return self.__class__(
ret = self.__class__(
name=name,
kind=kind,
default=default,
annotation=annotation,
description=description,
displayed_default=displayed_default,
displayed_name=displayed_name,
)
ret._fallback = self._fallback
return ret
if not TYPE_CHECKING: # this is to prevent anything breaking if inspect internals change
name = _gen_property('name')
@ -169,7 +180,21 @@ class Parameter(inspect.Parameter):
if self._displayed_default is not empty:
return self._displayed_default
return None if self.required else str(self.default)
if self.required:
return None
if callable(self.default) or self.default is None:
return None
return str(self.default)
@property
def displayed_name(self) -> Optional[str]:
"""Optional[:class:`str`]: The name that is displayed to the user.
.. versionadded:: 2.3
"""
return self._displayed_name if self._displayed_name is not empty else None
async def get_default(self, ctx: Context[Any]) -> Any:
"""|coro|
@ -183,7 +208,7 @@ class Parameter(inspect.Parameter):
"""
# pre-condition: required is False
if callable(self.default):
return await maybe_coroutine(self.default, ctx) # type: ignore
return await maybe_coroutine(self.default, ctx)
return self.default
@ -193,8 +218,9 @@ def parameter(
default: Any = empty,
description: str = empty,
displayed_default: str = empty,
displayed_name: str = empty,
) -> Any:
r"""parameter(\*, converter=..., default=..., description=..., displayed_default=...)
r"""parameter(\*, converter=..., default=..., description=..., displayed_default=..., displayed_name=...)
A way to assign custom metadata for a :class:`Command`\'s parameter.
@ -221,7 +247,17 @@ def parameter(
The description of this parameter.
displayed_default: :class:`str`
The displayed default in :attr:`Command.signature`.
displayed_name: :class:`str`
The name that is displayed to the user.
.. versionadded:: 2.3
"""
if isinstance(default, Parameter):
if displayed_default is empty:
displayed_default = default._displayed_default
default = default._default
return Parameter(
name='empty',
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
@ -229,6 +265,7 @@ def parameter(
default=default,
description=description,
displayed_default=displayed_default,
displayed_name=displayed_name,
)
@ -240,12 +277,13 @@ class ParameterAlias(Protocol):
default: Any = empty,
description: str = empty,
displayed_default: str = empty,
displayed_name: str = empty,
) -> Any:
...
param: ParameterAlias = parameter
r"""param(\*, converter=..., default=..., description=..., displayed_default=...)
r"""param(\*, converter=..., default=..., description=..., displayed_default=..., displayed_name=...)
An alias for :func:`parameter`.
@ -279,6 +317,7 @@ CurrentGuild = parameter(
displayed_default='<this server>',
converter=GuildConverter,
)
CurrentGuild._fallback = True
class Signature(inspect.Signature):

22
discord/ext/tasks/__init__.py

@ -111,12 +111,17 @@ class SleepHandle:
self.loop: asyncio.AbstractEventLoop = loop
self.future: asyncio.Future[None] = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, self.future.set_result, True)
self.handle = loop.call_later(relative_delta, self._wrapped_set_result, self.future)
@staticmethod
def _wrapped_set_result(future: asyncio.Future) -> None:
if not future.done():
future.set_result(None)
def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True)
self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self._wrapped_set_result, self.future)
def wait(self) -> asyncio.Future[Any]:
return self.future
@ -144,6 +149,7 @@ class Loop(Generic[LF]):
time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int],
reconnect: bool,
name: Optional[str],
) -> None:
self.coro: LF = coro
self.reconnect: bool = reconnect
@ -165,6 +171,7 @@ class Loop(Generic[LF]):
self._is_being_cancelled = False
self._has_failed = False
self._stop_next_iteration = False
self._name: str = f'discord-ext-tasks: {coro.__qualname__}' if name is None else name
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
@ -282,6 +289,7 @@ class Loop(Generic[LF]):
time=self._time,
count=self.count,
reconnect=self.reconnect,
name=self._name,
)
copy._injected = obj
copy._before_loop = self._before_loop
@ -395,7 +403,7 @@ class Loop(Generic[LF]):
args = (self._injected, *args)
self._has_failed = False
self._task = asyncio.create_task(self._loop(*args, **kwargs))
self._task = asyncio.create_task(self._loop(*args, **kwargs), name=self._name)
return self._task
def stop(self) -> None:
@ -770,6 +778,7 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
name: Optional[str] = None,
) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
@ -802,6 +811,12 @@ def loop(
Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`.
name: Optional[:class:`str`]
The name to assign to the internal task. By default
it is assigned a name based off of the callable name
such as ``discord-ext-tasks: function_name``.
.. versionadded:: 2.4
Raises
--------
@ -821,6 +836,7 @@ def loop(
count=count,
time=time,
reconnect=reconnect,
name=name,
)
return decorator

2
discord/file.py

@ -111,7 +111,7 @@ class File:
else:
filename = getattr(fp, 'name', 'untitled')
self._filename, filename_spoiler = _strip_spoiler(filename)
self._filename, filename_spoiler = _strip_spoiler(filename) # type: ignore # pyright doesn't understand the above getattr
if spoiler is MISSING:
spoiler = filename_spoiler

765
discord/flags.py

@ -25,7 +25,22 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from functools import reduce
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type, TypeVar, overload
from operator import or_
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
overload,
)
from .enums import UserFlags
@ -43,6 +58,13 @@ __all__ = (
'ChannelFlags',
'AutoModPresets',
'MemberFlags',
'AppCommandContext',
'AttachmentFlags',
'RoleFlags',
'AppInstallationType',
'SKUFlags',
'EmbedFlags',
'InviteFlags',
)
BF = TypeVar('BF', bound='BaseFlags')
@ -115,7 +137,7 @@ class BaseFlags:
setattr(self, key, value)
@classmethod
def _from_value(cls, value):
def _from_value(cls, value: int) -> Self:
self = cls.__new__(cls)
self.value = value
return self
@ -238,6 +260,12 @@ class SystemChannelFlags(BaseFlags):
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
.. versionadded:: 2.0
Attributes
-----------
value: :class:`int`
@ -360,6 +388,12 @@ class MessageFlags(BaseFlags):
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
.. versionadded:: 2.0
.. versionadded:: 1.3
Attributes
@ -450,6 +484,22 @@ class MessageFlags(BaseFlags):
"""
return 4096
@flag_value
def voice(self):
""":class:`bool`: Returns ``True`` if the message is a voice message.
.. versionadded:: 2.3
"""
return 8192
@flag_value
def forwarded(self):
""":class:`bool`: Returns ``True`` if the message is a forwarded message.
.. versionadded:: 2.5
"""
return 16384
@fill_with_flags()
class PublicUserFlags(BaseFlags):
@ -500,6 +550,12 @@ class PublicUserFlags(BaseFlags):
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
.. versionadded:: 2.0
.. versionadded:: 1.4
Attributes
@ -684,6 +740,12 @@ class Intents(BaseFlags):
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
.. describe:: bool(b)
Returns whether any intent is enabled.
.. versionadded:: 2.0
Attributes
-----------
value: :class:`int`
@ -783,6 +845,7 @@ class Intents(BaseFlags):
- :attr:`User.name`
- :attr:`User.avatar`
- :attr:`User.discriminator`
- :attr:`User.global_name`
For more information go to the :ref:`member intent documentation <need_members_intent>`.
@ -816,9 +879,9 @@ class Intents(BaseFlags):
"""
return 1 << 2
@flag_value
@alias_flag_value
def emojis(self):
""":class:`bool`: Alias of :attr:`.emojis_and_stickers`.
""":class:`bool`: Alias of :attr:`.expressions`.
.. versionchanged:: 2.0
Changed to an alias.
@ -827,25 +890,43 @@ class Intents(BaseFlags):
@alias_flag_value
def emojis_and_stickers(self):
""":class:`bool`: Whether guild emoji and sticker related events are enabled.
""":class:`bool`: Alias of :attr:`.expressions`.
.. versionadded:: 2.0
.. versionchanged:: 2.5
Changed to an alias.
"""
return 1 << 3
@flag_value
def expressions(self):
""":class:`bool`: Whether guild emoji, sticker, and soundboard sound related events are enabled.
.. versionadded:: 2.5
This corresponds to the following events:
- :func:`on_guild_emojis_update`
- :func:`on_guild_stickers_update`
- :func:`on_soundboard_sound_create`
- :func:`on_soundboard_sound_update`
- :func:`on_soundboard_sound_delete`
This also corresponds to the following attributes and classes in terms of cache:
- :class:`Emoji`
- :class:`GuildSticker`
- :class:`SoundboardSound`
- :meth:`Client.get_emoji`
- :meth:`Client.get_sticker`
- :meth:`Client.get_soundboard_sound`
- :meth:`Client.emojis`
- :meth:`Client.stickers`
- :meth:`Client.soundboard_sounds`
- :attr:`Guild.emojis`
- :attr:`Guild.stickers`
- :attr:`Guild.soundboard_sounds`
"""
return 1 << 3
@ -1161,7 +1242,7 @@ class Intents(BaseFlags):
"""
return 1 << 16
@flag_value
@alias_flag_value
def auto_moderation(self):
""":class:`bool`: Whether auto moderation related events are enabled.
@ -1204,6 +1285,57 @@ class Intents(BaseFlags):
"""
return 1 << 21
@alias_flag_value
def polls(self):
""":class:`bool`: Whether guild and direct messages poll related events are enabled.
This is a shortcut to set or get both :attr:`guild_polls` and :attr:`dm_polls`.
This corresponds to the following events:
- :func:`on_poll_vote_add` (both guilds and DMs)
- :func:`on_poll_vote_remove` (both guilds and DMs)
- :func:`on_raw_poll_vote_add` (both guilds and DMs)
- :func:`on_raw_poll_vote_remove` (both guilds and DMs)
.. versionadded:: 2.4
"""
return (1 << 24) | (1 << 25)
@flag_value
def guild_polls(self):
""":class:`bool`: Whether guild poll related events are enabled.
See also :attr:`dm_polls` and :attr:`polls`.
This corresponds to the following events:
- :func:`on_poll_vote_add` (only for guilds)
- :func:`on_poll_vote_remove` (only for guilds)
- :func:`on_raw_poll_vote_add` (only for guilds)
- :func:`on_raw_poll_vote_remove` (only for guilds)
.. versionadded:: 2.4
"""
return 1 << 24
@flag_value
def dm_polls(self):
""":class:`bool`: Whether direct messages poll related events are enabled.
See also :attr:`guild_polls` and :attr:`polls`.
This corresponds to the following events:
- :func:`on_poll_vote_add` (only for DMs)
- :func:`on_poll_vote_remove` (only for DMs)
- :func:`on_raw_poll_vote_add` (only for DMs)
- :func:`on_raw_poll_vote_remove` (only for DMs)
.. versionadded:: 2.4
"""
return 1 << 25
@fill_with_flags()
class MemberCacheFlags(BaseFlags):
@ -1269,6 +1401,12 @@ class MemberCacheFlags(BaseFlags):
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
.. versionadded:: 2.0
Attributes
-----------
value: :class:`int`
@ -1412,6 +1550,10 @@ class ApplicationFlags(BaseFlags):
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
.. versionadded:: 2.0
Attributes
@ -1421,6 +1563,15 @@ class ApplicationFlags(BaseFlags):
rather than using this raw value.
"""
@flag_value
def auto_mod_badge(self):
""":class:`bool`: Returns ``True`` if the application uses at least 100 automod rules across all guilds.
This shows up as a badge in the official client.
.. versionadded:: 2.3
"""
return 1 << 6
@flag_value
def gateway_presence(self):
""":class:`bool`: Returns ``True`` if the application is verified and is allowed to
@ -1538,6 +1689,10 @@ class ChannelFlags(BaseFlags):
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
.. versionadded:: 2.0
Attributes
@ -1561,16 +1716,49 @@ class ChannelFlags(BaseFlags):
"""
return 1 << 4
@flag_value
def hide_media_download_options(self):
""":class:`bool`: Returns ``True`` if the client hides embedded media download options in a :class:`ForumChannel`.
Only available in media channels.
.. versionadded:: 2.4
"""
return 1 << 15
class ArrayFlags(BaseFlags):
@classmethod
def _from_value(cls: Type[Self], value: List[int]) -> Self:
def _from_value(cls: Type[Self], value: Sequence[int]) -> Self:
self = cls.__new__(cls)
self.value = reduce(lambda a, b: a | (1 << b - 1), value, 0)
# This is a micro-optimization given the frequency this object can be created.
# (1).__lshift__ is used in place of lambda x: 1 << x
# prebinding to a method of a constant rather than define a lambda.
# Pairing this with map, is essentially equivalent to (1 << x for x in value)
# reduction using operator.or_ instead of defining a lambda each call
# Discord sends these starting with a value of 1
# Rather than subtract 1 from each element prior to left shift,
# we shift right by 1 once at the end.
self.value = reduce(or_, map((1).__lshift__, value), 0) >> 1
return self
def to_array(self) -> List[int]:
return [i + 1 for i in range(self.value.bit_length()) if self.value & (1 << i)]
def to_array(self, *, offset: int = 0) -> List[int]:
return [i + offset for i in range(self.value.bit_length()) if self.value & (1 << i)]
@classmethod
def all(cls: Type[Self]) -> Self:
"""A factory method that creates an instance of ArrayFlags with everything enabled."""
bits = max(cls.VALID_FLAGS.values()).bit_length()
value = (1 << bits) - 1
self = cls.__new__(cls)
self.value = value
return self
@classmethod
def none(cls: Type[Self]) -> Self:
"""A factory method that creates an instance of ArrayFlags with everything disabled."""
self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE
return self
@fill_with_flags()
@ -1626,6 +1814,10 @@ class AutoModPresets(ArrayFlags):
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
-----------
value: :class:`int`
@ -1633,6 +1825,9 @@ class AutoModPresets(ArrayFlags):
rather than using this raw value.
"""
def to_array(self) -> List[int]:
return super().to_array(offset=1)
@flag_value
def profanity(self):
""":class:`bool`: Whether to use the preset profanity filter."""
@ -1648,21 +1843,144 @@ class AutoModPresets(ArrayFlags):
""":class:`bool`: Whether to use the preset slurs filter."""
return 1 << 2
@classmethod
def all(cls: Type[Self]) -> Self:
"""A factory method that creates a :class:`AutoModPresets` with everything enabled."""
bits = max(cls.VALID_FLAGS.values()).bit_length()
value = (1 << bits) - 1
self = cls.__new__(cls)
self.value = value
return self
@classmethod
def none(cls: Type[Self]) -> Self:
"""A factory method that creates a :class:`AutoModPresets` with everything disabled."""
self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE
return self
@fill_with_flags()
class AppCommandContext(ArrayFlags):
r"""Wraps up the Discord :class:`~discord.app_commands.Command` execution context.
.. versionadded:: 2.4
.. container:: operations
.. describe:: x == y
Checks if two AppCommandContext flags are equal.
.. describe:: x != y
Checks if two AppCommandContext flags are not equal.
.. describe:: x | y, x |= y
Returns an AppCommandContext instance with all enabled flags from
both x and y.
.. describe:: x & y, x &= y
Returns an AppCommandContext instance with only flags enabled on
both x and y.
.. describe:: x ^ y, x ^= y
Returns an AppCommandContext instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns an AppCommandContext instance with all flags inverted from x
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
DEFAULT_VALUE = 3
@flag_value
def guild(self):
""":class:`bool`: Whether the context allows usage in a guild."""
return 1 << 0
@flag_value
def dm_channel(self):
""":class:`bool`: Whether the context allows usage in a DM channel."""
return 1 << 1
@flag_value
def private_channel(self):
""":class:`bool`: Whether the context allows usage in a DM or a GDM channel."""
return 1 << 2
@fill_with_flags()
class AppInstallationType(ArrayFlags):
r"""Represents the installation location of an application command.
.. versionadded:: 2.4
.. container:: operations
.. describe:: x == y
Checks if two AppInstallationType flags are equal.
.. describe:: x != y
Checks if two AppInstallationType flags are not equal.
.. describe:: x | y, x |= y
Returns an AppInstallationType instance with all enabled flags from
both x and y.
.. describe:: x & y, x &= y
Returns an AppInstallationType instance with only flags enabled on
both x and y.
.. describe:: x ^ y, x ^= y
Returns an AppInstallationType instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns an AppInstallationType instance with all flags inverted from x
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
@flag_value
def guild(self):
""":class:`bool`: Whether the integration is a guild install."""
return 1 << 0
@flag_value
def user(self):
""":class:`bool`: Whether the integration is a user install."""
return 1 << 1
@fill_with_flags()
@ -1710,6 +2028,10 @@ class MemberFlags(BaseFlags):
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
-----------
@ -1737,3 +2059,398 @@ class MemberFlags(BaseFlags):
def started_onboarding(self):
""":class:`bool`: Returns ``True`` if the member has started onboarding."""
return 1 << 3
@flag_value
def guest(self):
""":class:`bool`: Returns ``True`` if the member is a guest and can only access
the voice channel they were invited to.
.. versionadded:: 2.5
"""
return 1 << 4
@flag_value
def started_home_actions(self):
""":class:`bool`: Returns ``True`` if the member has started Server Guide new member actions.
.. versionadded:: 2.5
"""
return 1 << 5
@flag_value
def completed_home_actions(self):
""":class:`bool`: Returns ``True`` if the member has completed Server Guide new member actions.
.. versionadded:: 2.5
"""
return 1 << 6
@flag_value
def automod_quarantined_username(self):
""":class:`bool`: Returns ``True`` if the member's username, nickname, or global name has been
blocked by AutoMod.
.. versionadded:: 2.5
"""
return 1 << 7
@flag_value
def dm_settings_upsell_acknowledged(self):
""":class:`bool`: Returns ``True`` if the member has dismissed the DM settings upsell.
.. versionadded:: 2.5
"""
return 1 << 9
@fill_with_flags()
class AttachmentFlags(BaseFlags):
r"""Wraps up the Discord Attachment flags
.. versionadded:: 2.4
.. container:: operations
.. describe:: x == y
Checks if two AttachmentFlags are equal.
.. describe:: x != y
Checks if two AttachmentFlags are not equal.
.. describe:: x | y, x |= y
Returns a AttachmentFlags instance with all enabled flags from
both x and y.
.. describe:: x & y, x &= y
Returns a AttachmentFlags instance with only flags enabled on
both x and y.
.. describe:: x ^ y, x ^= y
Returns a AttachmentFlags instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns a AttachmentFlags instance with all flags inverted from x.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
@flag_value
def clip(self):
""":class:`bool`: Returns ``True`` if the attachment is a clip."""
return 1 << 0
@flag_value
def thumbnail(self):
""":class:`bool`: Returns ``True`` if the attachment is a thumbnail."""
return 1 << 1
@flag_value
def remix(self):
""":class:`bool`: Returns ``True`` if the attachment has been edited using the remix feature."""
return 1 << 2
@flag_value
def spoiler(self):
""":class:`bool`: Returns ``True`` if the attachment was marked as a spoiler.
.. versionadded:: 2.5
"""
return 1 << 3
@flag_value
def contains_explicit_media(self):
""":class:`bool`: Returns ``True`` if the attachment was flagged as sensitive content.
.. versionadded:: 2.5
"""
return 1 << 4
@flag_value
def animated(self):
""":class:`bool`: Returns ``True`` if the attachment is an animated image.
.. versionadded:: 2.5
"""
return 1 << 5
@fill_with_flags()
class RoleFlags(BaseFlags):
r"""Wraps up the Discord Role flags
.. versionadded:: 2.4
.. container:: operations
.. describe:: x == y
Checks if two RoleFlags are equal.
.. describe:: x != y
Checks if two RoleFlags are not equal.
.. describe:: x | y, x |= y
Returns a RoleFlags instance with all enabled flags from
both x and y.
.. describe:: x & y, x &= y
Returns a RoleFlags instance with only flags enabled on
both x and y.
.. describe:: x ^ y, x ^= y
Returns a RoleFlags instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns a RoleFlags instance with all flags inverted from x.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
@flag_value
def in_prompt(self):
""":class:`bool`: Returns ``True`` if the role can be selected by members in an onboarding prompt."""
return 1 << 0
@fill_with_flags()
class SKUFlags(BaseFlags):
r"""Wraps up the Discord SKU flags
.. versionadded:: 2.4
.. container:: operations
.. describe:: x == y
Checks if two SKUFlags are equal.
.. describe:: x != y
Checks if two SKUFlags are not equal.
.. describe:: x | y, x |= y
Returns a SKUFlags instance with all enabled flags from
both x and y.
.. describe:: x & y, x &= y
Returns a SKUFlags instance with only flags enabled on
both x and y.
.. describe:: x ^ y, x ^= y
Returns a SKUFlags instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns a SKUFlags instance with all flags inverted from x.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
@flag_value
def available(self):
""":class:`bool`: Returns ``True`` if the SKU is available for purchase."""
return 1 << 2
@flag_value
def guild_subscription(self):
""":class:`bool`: Returns ``True`` if the SKU is a guild subscription."""
return 1 << 7
@flag_value
def user_subscription(self):
""":class:`bool`: Returns ``True`` if the SKU is a user subscription."""
return 1 << 8
@fill_with_flags()
class EmbedFlags(BaseFlags):
r"""Wraps up the Discord Embed flags
.. versionadded:: 2.5
.. container:: operations
.. describe:: x == y
Checks if two EmbedFlags are equal.
.. describe:: x != y
Checks if two EmbedFlags are not equal.
.. describe:: x | y, x |= y
Returns an EmbedFlags instance with all enabled flags from
both x and y.
.. describe:: x ^ y, x ^= y
Returns an EmbedFlags instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns an EmbedFlags instance with all flags inverted from x.
.. describe:: hash(x)
Returns the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
@flag_value
def contains_explicit_media(self):
""":class:`bool`: Returns ``True`` if the embed was flagged as sensitive content."""
return 1 << 4
@flag_value
def content_inventory_entry(self):
""":class:`bool`: Returns ``True`` if the embed is a reply to an activity card, and is no
longer displayed.
"""
return 1 << 5
class InviteFlags(BaseFlags):
r"""Wraps up the Discord Invite flags
.. versionadded:: 2.6
.. container:: operations
.. describe:: x == y
Checks if two InviteFlags are equal.
.. describe:: x != y
Checks if two InviteFlags are not equal.
.. describe:: x | y, x |= y
Returns a InviteFlags instance with all enabled flags from
both x and y.
.. describe:: x ^ y, x ^= y
Returns a InviteFlags instance with only flags enabled on
only one of x or y, not on both.
.. describe:: ~x
Returns a InviteFlags instance with all flags inverted from x.
.. describe:: hash(x)
Returns the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether any flag is set to ``True``.
Attributes
----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
@flag_value
def guest(self):
""":class:`bool`: Returns ``True`` if this is a guest invite for a voice channel."""
return 1 << 0

167
discord/gateway.py

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
@ -32,9 +33,8 @@ import sys
import time
import threading
import traceback
import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
import aiohttp
import yarl
@ -59,7 +59,7 @@ if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
from .voice_state import VoiceConnectionState
class ReconnectWebSocket(Exception):
@ -132,11 +132,12 @@ class KeepAliveHandler(threading.Thread):
shard_id: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
daemon: bool = kwargs.pop('daemon', True)
name: str = kwargs.pop('name', f'keep-alive-handler:shard-{shard_id}')
super().__init__(*args, daemon=daemon, name=name, **kwargs)
self.ws: DiscordWebSocket = ws
self._main_thread_id: int = ws.thread_id
self.interval: Optional[float] = interval
self.daemon: bool = True
self.shard_id: Optional[int] = shard_id
self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
@ -211,8 +212,12 @@ class KeepAliveHandler(threading.Thread):
class VoiceKeepAliveHandler(KeepAliveHandler):
if TYPE_CHECKING:
ws: DiscordVoiceWebSocket
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
name: str = kwargs.pop('name', f'voice-keep-alive-handler:{id(self):#x}')
super().__init__(*args, name=name, **kwargs)
self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
@ -221,7 +226,10 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
def get_payload(self) -> Dict[str, Any]:
return {
'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000),
'd': {
't': int(time.time() * 1000),
'seq_ack': self.ws.seq_ack,
},
}
def ack(self) -> None:
@ -293,19 +301,19 @@ class DiscordWebSocket:
# fmt: off
DEFAULT_GATEWAY = yarl.URL('wss://gateway.discord.gg/')
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
PRESENCE = 3
VOICE_STATE = 4
VOICE_PING = 5
RESUME = 6
RECONNECT = 7
REQUEST_MEMBERS = 8
INVALIDATE_SESSION = 9
HELLO = 10
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
PRESENCE = 3
VOICE_STATE = 4
VOICE_PING = 5
RESUME = 6
RECONNECT = 7
REQUEST_MEMBERS = 8
INVALIDATE_SESSION = 9
HELLO = 10
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
# fmt: on
def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
@ -323,8 +331,7 @@ class DiscordWebSocket:
# ws related stuff
self.session_id: Optional[str] = None
self.sequence: Optional[int] = None
self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray()
self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
@ -353,7 +360,7 @@ class DiscordWebSocket:
sequence: Optional[int] = None,
resume: bool = False,
encoding: str = 'json',
zlib: bool = True,
compress: bool = True,
) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`.
@ -364,10 +371,12 @@ class DiscordWebSocket:
gateway = gateway or cls.DEFAULT_GATEWAY
if zlib:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
if not compress:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
else:
url = gateway.with_query(
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
)
socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop)
@ -486,13 +495,11 @@ class DiscordWebSocket:
async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes:
self._buffer.extend(msg)
msg = self._decompressor.decompress(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
# Received a partial gateway message
if msg is None:
return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()
self.log_receive(msg)
msg = utils._from_json(msg)
@ -605,7 +612,10 @@ class DiscordWebSocket:
def _can_handle_close(self) -> bool:
code = self._close_code or self.socket.close_code
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
# If the socket is closed remotely with 1000 and it's not our own explicit close
# then it's an improper close that should be handled and reconnected
is_improper_close = self._close_code is None and self.socket.close_code == 1000
return is_improper_close or code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
async def poll_event(self) -> None:
"""Polls for a DISPATCH event and handles the general gateway loop.
@ -622,8 +632,8 @@ class DiscordWebSocket:
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
raise msg.data
_log.debug('Received error %s', msg)
raise WebSocketClosure
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
_log.debug('Received %s', msg)
raise WebSocketClosure
@ -795,7 +805,7 @@ class DiscordVoiceWebSocket:
if TYPE_CHECKING:
thread_id: int
_connection: VoiceClient
_connection: VoiceConnectionState
gateway: str
_max_heartbeat_timeout: float
@ -825,9 +835,11 @@ class DiscordVoiceWebSocket:
self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive: Optional[VoiceKeepAliveHandler] = None
self._close_code: Optional[int] = None
self.secret_key: Optional[str] = None
self.secret_key: Optional[List[int]] = None
# defaulting to -1
self.seq_ack: int = -1
if hook:
self._hook = hook
self._hook = hook # type: ignore
async def _hook(self, *args: Any) -> None:
pass
@ -846,6 +858,7 @@ class DiscordVoiceWebSocket:
'token': state.token,
'server_id': str(state.server_id),
'session_id': state.session_id,
'seq_ack': self.seq_ack,
},
}
await self.send_as_json(payload)
@ -864,16 +877,23 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
@classmethod
async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
async def from_connection_state(
cls,
state: VoiceConnectionState,
*,
resume: bool = False,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
seq_ack: int = -1,
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
gateway = f'wss://{state.endpoint}/?v=8'
client = state.voice_client
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook)
ws.gateway = gateway
ws._connection = client
ws.seq_ack = seq_ack
ws._connection = state
ws._max_heartbeat_timeout = 60.0
ws.thread_id = threading.get_ident()
@ -884,7 +904,7 @@ class DiscordVoiceWebSocket:
return ws
async def select_protocol(self, ip: str, port: int, mode: int) -> None:
async def select_protocol(self, ip: str, port: int, mode: str) -> None:
payload = {
'op': self.SELECT_PROTOCOL,
'd': {
@ -915,6 +935,7 @@ class DiscordVoiceWebSocket:
'd': {
'speaking': int(state),
'delay': 0,
'ssrc': self._connection.ssrc,
},
}
@ -924,6 +945,7 @@ class DiscordVoiceWebSocket:
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
data = msg['d'] # According to Discord this key is always given
self.seq_ack = msg.get('seq', self.seq_ack) # this key could not be given
if op == self.READY:
await self.initial_connection(data)
@ -948,30 +970,50 @@ class DiscordVoiceWebSocket:
state.voice_port = data['port']
state.endpoint_ip = data['ip']
packet = bytearray(70)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 70)
_log.debug('received packet in initial_connection: %s', recv)
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port)
_log.debug('Connecting to voice socket')
await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port))
state.ip, state.port = await self.discover_ip()
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes))
_log.debug('received supported encryption modes: %s', ', '.join(modes))
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.debug('selected the voice protocol for use (%s)', mode)
async def discover_ip(self) -> Tuple[str, int]:
state = self._connection
packet = bytearray(74)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
_log.debug('Sending ip discovery packet')
await self.loop.sock_sendall(state.socket, packet)
fut: asyncio.Future[bytes] = self.loop.create_future()
def get_ip_packet(data: bytes):
if data[1] == 0x02 and len(data) == 74:
self.loop.call_soon_threadsafe(fut.set_result, data)
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
state.add_socket_listener(get_ip_packet)
recv = await fut
_log.debug('Received ip discovery packet: %s', recv)
# the ip is ascii starting at the 8th byte and ending at the first null
ip_start = 8
ip_end = recv.index(0, ip_start)
ip = recv[ip_start:ip_end].decode('ascii')
port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', ip, port)
return ip, port
@property
def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
@ -990,7 +1032,10 @@ class DiscordVoiceWebSocket:
async def load_secret_key(self, data: Dict[str, Any]) -> None:
_log.debug('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data['secret_key']
await self.speak()
# Send a speak command with the "not speaking" state.
# This also tells Discord our SSRC value, which Discord requires before
# sending any voice data (and is the real reason why we call this here).
await self.speak(SpeakingState.none)
async def poll_event(self) -> None:
@ -999,10 +1044,10 @@ class DiscordVoiceWebSocket:
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code: int = 1000) -> None:

1137
discord/guild.py

File diff suppressed because it is too large

446
discord/http.py

@ -67,7 +67,7 @@ if TYPE_CHECKING:
from .embeds import Embed
from .message import Attachment
from .flags import MessageFlags
from .enums import AuditLogAction
from .poll import Poll
from .types import (
appinfo,
@ -90,9 +90,14 @@ if TYPE_CHECKING:
scheduled_event,
sticker,
welcome_screen,
onboarding,
sku,
poll,
voice,
soundboard,
subscription,
)
from .types.snowflake import Snowflake, SnowflakeList
from .types.gateway import SessionStartLimit
from types import TracebackType
@ -153,6 +158,8 @@ def handle_message_parameters(
mention_author: Optional[bool] = None,
thread_name: str = MISSING,
channel_payload: Dict[str, Any] = MISSING,
applied_tags: Optional[SnowflakeList] = MISSING,
poll: Optional[Poll] = MISSING,
) -> MultipartParameters:
if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.')
@ -191,6 +198,7 @@ def handle_message_parameters(
if nonce is not None:
payload['nonce'] = str(nonce)
payload['enforce_nonce'] = True
if message_reference is not MISSING:
payload['message_reference'] = message_reference
@ -243,12 +251,21 @@ def handle_message_parameters(
payload['attachments'] = attachments_payload
if applied_tags is not MISSING:
if applied_tags is not None:
payload['applied_tags'] = applied_tags
else:
payload['applied_tags'] = []
if channel_payload is not MISSING:
payload = {
'message': payload,
}
payload.update(channel_payload)
if poll not in (MISSING, None):
payload['poll'] = poll._to_dict() # type: ignore
multipart = []
if files:
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
@ -293,7 +310,7 @@ class Route:
self.metadata: Optional[str] = metadata
url = self.BASE + self.path
if parameters:
url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()})
url = url.format_map({k: _uriquote(v, safe='') if isinstance(v, str) else v for k, v in parameters.items()})
self.url: str = url
# major parameters:
@ -444,7 +461,12 @@ class Ratelimit:
future = self._loop.create_future()
self._pending_requests.append(future)
try:
await future
while not future.done():
# 30 matches the smallest allowed max_ratelimit_timeout
max_wait_time = self.expires - self._loop.time() if self.expires else 30
await asyncio.wait([future], timeout=max_wait_time)
if not future.done():
await self._refresh()
except:
future.cancel()
if self.remaining > 0 and not future.cancelled():
@ -651,14 +673,13 @@ class HTTPClient:
_log.debug(fmt, route_key, bucket_hash, discord_hash)
self._bucket_hashes[route_key] = discord_hash
recalculated_key = discord_hash + route.major_parameters
self._buckets[recalculated_key] = ratelimit
self._buckets[f'{discord_hash}:{route.major_parameters}'] = ratelimit
self._buckets.pop(key, None)
elif route_key not in self._bucket_hashes:
fmt = '%s has found its initial rate limit bucket hash (%s).'
_log.debug(fmt, route_key, discord_hash)
self._bucket_hashes[route_key] = discord_hash
self._buckets[discord_hash + route.major_parameters] = ratelimit
self._buckets[f'{discord_hash}:{route.major_parameters}'] = ratelimit
if has_ratelimit_headers:
if response.status != 429:
@ -762,7 +783,15 @@ class HTTPClient:
raise RuntimeError('Unreachable code in HTTP handling')
async def get_from_cdn(self, url: str) -> bytes:
async with self.__session.get(url) as resp:
kwargs = {}
# Proxy support
if self.proxy is not None:
kwargs['proxy'] = self.proxy
if self.proxy_auth is not None:
kwargs['proxy_auth'] = self.proxy_auth
async with self.__session.get(url, **kwargs) as resp:
if resp.status == 200:
return await resp.read()
elif resp.status == 404:
@ -791,6 +820,7 @@ class HTTPClient:
connector=self.connector,
ws_response_class=DiscordClientWebSocketResponse,
trace_configs=None if self.http_trace is None else [self.http_trace],
cookie_jar=aiohttp.DummyCookieJar(),
)
self._global_over = asyncio.Event()
self._global_over.set()
@ -927,6 +957,7 @@ class HTTPClient:
emoji: str,
limit: int,
after: Optional[Snowflake] = None,
type: Optional[message.ReactionType] = None,
) -> Response[List[user.User]]:
r = Route(
'GET',
@ -941,6 +972,10 @@ class HTTPClient:
}
if after:
params['after'] = after
if type is not None:
params['type'] = type
return self.request(r, params=params)
def clear_reactions(self, channel_id: Snowflake, message_id: Snowflake) -> Response[None]:
@ -1047,6 +1082,20 @@ class HTTPClient:
r = Route('DELETE', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id)
return self.request(r, reason=reason)
def bulk_ban(
self,
guild_id: Snowflake,
user_ids: List[Snowflake],
delete_message_seconds: int = 86400,
reason: Optional[str] = None,
) -> Response[guild.BulkBanUserResponse]:
r = Route('POST', '/guilds/{guild_id}/bulk-ban', guild_id=guild_id)
payload = {
'user_ids': user_ids,
'delete_message_seconds': delete_message_seconds,
}
return self.request(r, json=payload, reason=reason)
def guild_voice_state(
self,
user_id: Snowflake,
@ -1115,6 +1164,12 @@ class HTTPClient:
r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id)
return self.request(r, json=fields, reason=reason)
def get_my_voice_state(self, guild_id: Snowflake) -> Response[voice.GuildVoiceState]:
return self.request(Route('GET', '/guilds/{guild_id}/voice-states/@me', guild_id=guild_id))
def get_voice_state(self, guild_id: Snowflake, user_id: Snowflake) -> Response[voice.GuildVoiceState]:
return self.request(Route('GET', '/guilds/{guild_id}/voice-states/{user_id}', guild_id=guild_id, user_id=user_id))
# Channel management
def edit_channel(
@ -1149,11 +1204,19 @@ class HTTPClient:
'available_tags',
'applied_tags',
'default_forum_layout',
'default_sort_order',
)
payload = {k: v for k, v in options.items() if k in valid_keys}
return self.request(r, reason=reason, json=payload)
def edit_voice_channel_status(
self, status: Optional[str], *, channel_id: int, reason: Optional[str] = None
) -> Response[None]:
r = Route('PUT', '/channels/{channel_id}/voice-status', channel_id=channel_id)
payload = {'status': status}
return self.request(r, reason=reason, json=payload)
def bulk_channel_update(
self,
guild_id: Snowflake,
@ -1190,6 +1253,9 @@ class HTTPClient:
'video_quality_mode',
'default_auto_archive_duration',
'default_thread_rate_limit_per_user',
'default_sort_order',
'default_reaction_emoji',
'default_forum_layout',
'available_tags',
)
payload.update({k: v for k, v in options.items() if k in valid_keys and v is not None})
@ -1370,9 +1436,11 @@ class HTTPClient:
limit: int,
before: Optional[Snowflake] = None,
after: Optional[Snowflake] = None,
with_counts: bool = True,
) -> Response[List[guild.Guild]]:
params: Dict[str, Any] = {
'limit': limit,
'with_counts': int(with_counts),
}
if before:
@ -1389,6 +1457,9 @@ class HTTPClient:
params = {'with_counts': int(with_counts)}
return self.request(Route('GET', '/guilds/{guild_id}', guild_id=guild_id), params=params)
def get_guild_preview(self, guild_id: Snowflake) -> Response[guild.GuildPreview]:
return self.request(Route('GET', '/guilds/{guild_id}/preview', guild_id=guild_id))
def delete_guild(self, guild_id: Snowflake) -> Response[None]:
return self.request(Route('DELETE', '/guilds/{guild_id}', guild_id=guild_id))
@ -1423,12 +1494,19 @@ class HTTPClient:
'public_updates_channel_id',
'preferred_locale',
'premium_progress_bar_enabled',
'safety_alerts_channel_id',
)
payload = {k: v for k, v in fields.items() if k in valid_keys}
return self.request(Route('PATCH', '/guilds/{guild_id}', guild_id=guild_id), json=payload, reason=reason)
def edit_guild_mfa_level(
self, guild_id: Snowflake, *, mfa_level: int, reason: Optional[str] = None
) -> Response[guild.GuildMFALevel]:
payload = {'level': mfa_level}
return self.request(Route('POST', '/guilds/{guild_id}/mfa', guild_id=guild_id), json=payload, reason=reason)
def get_template(self, code: str) -> Response[template.Template]:
return self.request(Route('GET', '/guilds/templates/{code}', code=code))
@ -1558,6 +1636,9 @@ class HTTPClient:
def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]:
return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id))
def get_sticker_pack(self, sticker_pack_id: Snowflake) -> Response[sticker.StickerPack]:
return self.request(Route('GET', '/sticker-packs/{sticker_pack_id}', sticker_pack_id=sticker_pack_id))
def list_premium_sticker_packs(self) -> Response[sticker.ListPremiumStickerPacks]:
return self.request(Route('GET', '/sticker-packs'))
@ -1713,12 +1794,12 @@ class HTTPClient:
before: Optional[Snowflake] = None,
after: Optional[Snowflake] = None,
user_id: Optional[Snowflake] = None,
action_type: Optional[AuditLogAction] = None,
action_type: Optional[audit_log.AuditLogEvent] = None,
) -> Response[audit_log.AuditLog]:
params: Dict[str, Any] = {'limit': limit}
if before:
params['before'] = before
if after:
if after is not None:
params['after'] = after
if user_id:
params['user_id'] = user_id
@ -1736,8 +1817,8 @@ class HTTPClient:
) -> Response[widget.WidgetSettings]:
return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason)
def get_guild_onboarding(self, guild_id: Snowflake) -> Response[onboarding.Onboarding]:
return self.request(Route('GET', '/guilds/{guild_id}/onboarding', guild_id=guild_id))
def edit_incident_actions(self, guild_id: Snowflake, payload: guild.IncidentData) -> Response[guild.IncidentData]:
return self.request(Route('PUT', '/guilds/{guild_id}/incident-actions', guild_id=guild_id), json=payload)
# Invite management
@ -1753,6 +1834,7 @@ class HTTPClient:
target_type: Optional[invite.InviteTargetType] = None,
target_user_id: Optional[Snowflake] = None,
target_application_id: Optional[Snowflake] = None,
flags: Optional[int] = None,
) -> Response[invite.Invite]:
r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id)
payload = {
@ -1771,6 +1853,9 @@ class HTTPClient:
if target_application_id:
payload['target_application_id'] = str(target_application_id)
if flags:
payload['flags'] = flags
return self.request(r, reason=reason, json=payload)
def get_invite(
@ -1797,7 +1882,7 @@ class HTTPClient:
def invites_from_channel(self, channel_id: Snowflake) -> Response[List[invite.Invite]]:
return self.request(Route('GET', '/channels/{channel_id}/invites', channel_id=channel_id))
def delete_invite(self, invite_id: str, *, reason: Optional[str] = None) -> Response[None]:
def delete_invite(self, invite_id: str, *, reason: Optional[str] = None) -> Response[invite.Invite]:
return self.request(Route('DELETE', '/invites/{invite_id}', invite_id=invite_id), reason=reason)
# Role management
@ -1805,6 +1890,9 @@ class HTTPClient:
def get_roles(self, guild_id: Snowflake) -> Response[List[role.Role]]:
return self.request(Route('GET', '/guilds/{guild_id}/roles', guild_id=guild_id))
def get_role(self, guild_id: Snowflake, role_id: Snowflake) -> Response[role.Role]:
return self.request(Route('GET', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id))
def edit_role(
self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any
) -> Response[role.Role]:
@ -1907,6 +1995,8 @@ class HTTPClient:
'channel_id',
'topic',
'privacy_level',
'send_start_notification',
'guild_scheduled_event_id',
)
payload = {k: v for k, v in payload.items() if k in valid_keys}
@ -2363,33 +2453,329 @@ class HTTPClient:
reason=reason,
)
# Misc
# SKU
def get_skus(self, application_id: Snowflake) -> Response[List[sku.SKU]]:
return self.request(Route('GET', '/applications/{application_id}/skus', application_id=application_id))
def get_entitlements(
self,
application_id: Snowflake,
user_id: Optional[Snowflake] = None,
sku_ids: Optional[SnowflakeList] = None,
before: Optional[Snowflake] = None,
after: Optional[Snowflake] = None,
limit: Optional[int] = None,
guild_id: Optional[Snowflake] = None,
exclude_ended: Optional[bool] = None,
exclude_deleted: Optional[bool] = None,
) -> Response[List[sku.Entitlement]]:
params: Dict[str, Any] = {}
if user_id is not None:
params['user_id'] = user_id
if sku_ids is not None:
params['sku_ids'] = ','.join(map(str, sku_ids))
if before is not None:
params['before'] = before
if after is not None:
params['after'] = after
if limit is not None:
params['limit'] = limit
if guild_id is not None:
params['guild_id'] = guild_id
if exclude_ended is not None:
params['exclude_ended'] = int(exclude_ended)
if exclude_deleted is not None:
params['exclude_deleted'] = int(exclude_deleted)
return self.request(
Route('GET', '/applications/{application_id}/entitlements', application_id=application_id), params=params
)
def get_entitlement(self, application_id: Snowflake, entitlement_id: Snowflake) -> Response[sku.Entitlement]:
return self.request(
Route(
'GET',
'/applications/{application_id}/entitlements/{entitlement_id}',
application_id=application_id,
entitlement_id=entitlement_id,
),
)
def consume_entitlement(self, application_id: Snowflake, entitlement_id: Snowflake) -> Response[None]:
return self.request(
Route(
'POST',
'/applications/{application_id}/entitlements/{entitlement_id}/consume',
application_id=application_id,
entitlement_id=entitlement_id,
),
)
def create_entitlement(
self, application_id: Snowflake, sku_id: Snowflake, owner_id: Snowflake, owner_type: sku.EntitlementOwnerType
) -> Response[sku.Entitlement]:
payload = {
'sku_id': sku_id,
'owner_id': owner_id,
'owner_type': owner_type,
}
return self.request(
Route(
'POST',
'/applications/{application_id}/entitlements',
application_id=application_id,
),
json=payload,
)
def delete_entitlement(self, application_id: Snowflake, entitlement_id: Snowflake) -> Response[None]:
return self.request(
Route(
'DELETE',
'/applications/{application_id}/entitlements/{entitlement_id}',
application_id=application_id,
entitlement_id=entitlement_id,
),
)
# Soundboard
def get_soundboard_default_sounds(self) -> Response[List[soundboard.SoundboardDefaultSound]]:
return self.request(Route('GET', '/soundboard-default-sounds'))
def get_soundboard_sound(self, guild_id: Snowflake, sound_id: Snowflake) -> Response[soundboard.SoundboardSound]:
return self.request(
Route('GET', '/guilds/{guild_id}/soundboard-sounds/{sound_id}', guild_id=guild_id, sound_id=sound_id)
)
def get_soundboard_sounds(self, guild_id: Snowflake) -> Response[Dict[str, List[soundboard.SoundboardSound]]]:
return self.request(Route('GET', '/guilds/{guild_id}/soundboard-sounds', guild_id=guild_id))
def create_soundboard_sound(
self, guild_id: Snowflake, *, reason: Optional[str], **payload: Any
) -> Response[soundboard.SoundboardSound]:
valid_keys = (
'name',
'sound',
'volume',
'emoji_id',
'emoji_name',
)
payload = {k: v for k, v in payload.items() if k in valid_keys and v is not None}
return self.request(
Route('POST', '/guilds/{guild_id}/soundboard-sounds', guild_id=guild_id), json=payload, reason=reason
)
def edit_soundboard_sound(
self, guild_id: Snowflake, sound_id: Snowflake, *, reason: Optional[str], **payload: Any
) -> Response[soundboard.SoundboardSound]:
valid_keys = (
'name',
'volume',
'emoji_id',
'emoji_name',
)
payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(
Route(
'PATCH',
'/guilds/{guild_id}/soundboard-sounds/{sound_id}',
guild_id=guild_id,
sound_id=sound_id,
),
json=payload,
reason=reason,
)
def delete_soundboard_sound(self, guild_id: Snowflake, sound_id: Snowflake, *, reason: Optional[str]) -> Response[None]:
return self.request(
Route(
'DELETE',
'/guilds/{guild_id}/soundboard-sounds/{sound_id}',
guild_id=guild_id,
sound_id=sound_id,
),
reason=reason,
)
def send_soundboard_sound(self, channel_id: Snowflake, **payload: Any) -> Response[None]:
valid_keys = ('sound_id', 'source_guild_id')
payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(
(Route('POST', '/channels/{channel_id}/send-soundboard-sound', channel_id=channel_id)), json=payload
)
# Application
def application_info(self) -> Response[appinfo.AppInfo]:
return self.request(Route('GET', '/oauth2/applications/@me'))
async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str:
try:
data = await self.request(Route('GET', '/gateway'))
except HTTPException as exc:
raise GatewayNotFound() from exc
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return value.format(data['url'], encoding, INTERNAL_API_VERSION)
def edit_application_info(self, *, reason: Optional[str], payload: Any) -> Response[appinfo.AppInfo]:
valid_keys = (
'custom_install_url',
'description',
'role_connections_verification_url',
'install_params',
'flags',
'icon',
'cover_image',
'interactions_endpoint_url ',
'tags',
'integration_types_config',
)
payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(Route('PATCH', '/applications/@me'), json=payload, reason=reason)
def get_application_emojis(self, application_id: Snowflake) -> Response[appinfo.ListAppEmojis]:
return self.request(Route('GET', '/applications/{application_id}/emojis', application_id=application_id))
def get_application_emoji(self, application_id: Snowflake, emoji_id: Snowflake) -> Response[emoji.Emoji]:
return self.request(
Route(
'GET', '/applications/{application_id}/emojis/{emoji_id}', application_id=application_id, emoji_id=emoji_id
)
)
def create_application_emoji(
self,
application_id: Snowflake,
name: str,
image: str,
) -> Response[emoji.Emoji]:
payload = {
'name': name,
'image': image,
}
return self.request(
Route('POST', '/applications/{application_id}/emojis', application_id=application_id), json=payload
)
def edit_application_emoji(
self,
application_id: Snowflake,
emoji_id: Snowflake,
*,
payload: Dict[str, Any],
) -> Response[emoji.Emoji]:
r = Route(
'PATCH', '/applications/{application_id}/emojis/{emoji_id}', application_id=application_id, emoji_id=emoji_id
)
return self.request(r, json=payload)
def delete_application_emoji(
self,
application_id: Snowflake,
emoji_id: Snowflake,
) -> Response[None]:
return self.request(
Route(
'DELETE',
'/applications/{application_id}/emojis/{emoji_id}',
application_id=application_id,
emoji_id=emoji_id,
)
)
# Poll
async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]:
def get_poll_answer_voters(
self,
channel_id: Snowflake,
message_id: Snowflake,
answer_id: Snowflake,
after: Optional[Snowflake] = None,
limit: Optional[int] = None,
) -> Response[poll.PollAnswerVoters]:
params = {}
if after:
params['after'] = int(after)
if limit is not None:
params['limit'] = limit
return self.request(
Route(
'GET',
'/channels/{channel_id}/polls/{message_id}/answers/{answer_id}',
channel_id=channel_id,
message_id=message_id,
answer_id=answer_id,
),
params=params,
)
def end_poll(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]:
return self.request(
Route(
'POST',
'/channels/{channel_id}/polls/{message_id}/expire',
channel_id=channel_id,
message_id=message_id,
)
)
# Subscriptions
def list_sku_subscriptions(
self,
sku_id: Snowflake,
before: Optional[Snowflake] = None,
after: Optional[Snowflake] = None,
limit: Optional[int] = None,
user_id: Optional[Snowflake] = None,
) -> Response[List[subscription.Subscription]]:
params = {}
if before is not None:
params['before'] = before
if after is not None:
params['after'] = after
if limit is not None:
params['limit'] = limit
if user_id is not None:
params['user_id'] = user_id
return self.request(
Route(
'GET',
'/skus/{sku_id}/subscriptions',
sku_id=sku_id,
),
params=params,
)
def get_sku_subscription(self, sku_id: Snowflake, subscription_id: Snowflake) -> Response[subscription.Subscription]:
return self.request(
Route(
'GET',
'/skus/{sku_id}/subscriptions/{subscription_id}',
sku_id=sku_id,
subscription_id=subscription_id,
)
)
# Misc
async def get_bot_gateway(self) -> Tuple[int, str, SessionStartLimit]:
try:
data = await self.request(Route('GET', '/gateway/bot'))
except HTTPException as exc:
raise GatewayNotFound() from exc
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION)
return data['shards'], data['url'], data['session_start_limit']
def get_user(self, user_id: Snowflake) -> Response[user.User]:
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))

422
discord/interactions.py

@ -25,7 +25,9 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Dict, Optional, Generic, TYPE_CHECKING, Sequence, Tuple, Union
import logging
from typing import Any, Dict, Optional, Generic, TYPE_CHECKING, Sequence, Tuple, Union, List
import asyncio
import datetime
@ -33,8 +35,9 @@ from . import utils
from .enums import try_enum, Locale, InteractionType, InteractionResponseType
from .errors import InteractionResponded, HTTPException, ClientException, DiscordException
from .flags import MessageFlags
from .channel import PartialMessageable, ChannelType
from .channel import ChannelType
from ._types import ClientT
from .sku import Entitlement
from .user import User
from .member import Member
@ -42,13 +45,17 @@ from .message import Message, Attachment
from .permissions import Permissions
from .http import handle_message_parameters
from .webhook.async_ import async_context, Webhook, interaction_response_params, interaction_message_response_params
from .app_commands.installs import AppCommandContext
from .app_commands.namespace import Namespace
from .app_commands.translator import locale_str, TranslationContext, TranslationContextLocation
from .channel import _threaded_channel_factory
__all__ = (
'Interaction',
'InteractionMessage',
'InteractionResponse',
'InteractionCallbackResponse',
'InteractionCallbackActivityInstance',
)
if TYPE_CHECKING:
@ -56,10 +63,13 @@ if TYPE_CHECKING:
Interaction as InteractionPayload,
InteractionData,
ApplicationCommandInteractionData,
InteractionCallback as InteractionCallbackPayload,
InteractionCallbackActivity as InteractionCallbackActivityPayload,
)
from .types.webhook import (
Webhook as WebhookPayload,
)
from .types.snowflake import Snowflake
from .guild import Guild
from .state import ConnectionState
from .file import File
@ -69,12 +79,24 @@ if TYPE_CHECKING:
from .ui.view import View
from .app_commands.models import Choice, ChoiceT
from .ui.modal import Modal
from .channel import VoiceChannel, StageChannel, TextChannel, ForumChannel, CategoryChannel
from .channel import VoiceChannel, StageChannel, TextChannel, ForumChannel, CategoryChannel, DMChannel, GroupChannel
from .threads import Thread
from .app_commands.commands import Command, ContextMenu
from .poll import Poll
InteractionChannel = Union[
VoiceChannel, StageChannel, TextChannel, ForumChannel, CategoryChannel, Thread, PartialMessageable
VoiceChannel,
StageChannel,
TextChannel,
ForumChannel,
CategoryChannel,
Thread,
DMChannel,
GroupChannel,
]
InteractionCallbackResource = Union[
"InteractionMessage",
"InteractionCallbackActivityInstance",
]
MISSING: Any = utils.MISSING
@ -96,8 +118,14 @@ class Interaction(Generic[ClientT]):
The interaction type.
guild_id: Optional[:class:`int`]
The guild ID the interaction was sent from.
channel_id: Optional[:class:`int`]
The channel ID the interaction was sent from.
channel: Optional[Union[:class:`abc.GuildChannel`, :class:`abc.PrivateChannel`, :class:`Thread`]]
The channel the interaction was sent from.
Note that due to a Discord limitation, if sent from a DM channel :attr:`~DMChannel.recipient` is ``None``.
entitlement_sku_ids: List[:class:`int`]
The entitlement SKU IDs that the user has.
entitlements: List[:class:`Entitlement`]
The entitlements that the guild or user has.
application_id: :class:`int`
The application ID that the interaction was for.
user: Union[:class:`User`, :class:`Member`]
@ -122,13 +150,20 @@ class Interaction(Generic[ClientT]):
command_failed: :class:`bool`
Whether the command associated with this interaction failed to execute.
This includes checks and execution.
context: :class:`.AppCommandContext`
The context of the interaction.
.. versionadded:: 2.4
filesize_limit: int
The maximum number of bytes a file can have when responding to this interaction.
.. versionadded:: 2.6
"""
__slots__: Tuple[str, ...] = (
'id',
'type',
'guild_id',
'channel_id',
'data',
'application_id',
'message',
@ -139,6 +174,11 @@ class Interaction(Generic[ClientT]):
'guild_locale',
'extras',
'command_failed',
'entitlement_sku_ids',
'entitlements',
'context',
'filesize_limit',
'_integration_owners',
'_permissions',
'_app_permissions',
'_state',
@ -148,7 +188,7 @@ class Interaction(Generic[ClientT]):
'_original_response',
'_cs_response',
'_cs_followup',
'_cs_channel',
'channel',
'_cs_namespace',
'_cs_command',
)
@ -165,23 +205,61 @@ class Interaction(Generic[ClientT]):
self.command_failed: bool = False
self._from_data(data)
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} type={self.type!r} guild_id={self.guild_id!r} user={self.user!r}>'
def _from_data(self, data: InteractionPayload):
self.id: int = int(data['id'])
self.type: InteractionType = try_enum(InteractionType, data['type'])
self.data: Optional[InteractionData] = data.get('data')
self.token: str = data['token']
self.version: int = data['version']
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.channel: Optional[InteractionChannel] = None
self.application_id: int = int(data['application_id'])
self.entitlement_sku_ids: List[int] = [int(x) for x in data.get('entitlement_skus', []) or []]
self.entitlements: List[Entitlement] = [Entitlement(self._state, x) for x in data.get('entitlements', [])]
self.filesize_limit: int = data['attachment_size_limit']
# This is not entirely useful currently, unsure how to expose it in a way that it is.
self._integration_owners: Dict[int, Snowflake] = {
int(k): int(v) for k, v in data.get('authorizing_integration_owners', {}).items()
}
try:
value = data['context'] # pyright: ignore[reportTypedDictNotRequiredAccess]
self.context = AppCommandContext._from_value([value])
except KeyError:
self.context = AppCommandContext()
self.locale: Locale = try_enum(Locale, data.get('locale', 'en-US'))
self.guild_locale: Optional[Locale]
try:
self.guild_locale = try_enum(Locale, data['guild_locale'])
self.guild_locale = try_enum(Locale, data['guild_locale']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.guild_locale = None
guild = None
if self.guild_id:
# The data type is a TypedDict but it doesn't narrow to Dict[str, Any] properly
guild = self._state._get_or_create_unavailable_guild(self.guild_id, data=data.get('guild')) # type: ignore
if guild.me is None and self._client.user is not None:
guild._add_member(Member._from_client_user(user=self._client.user, guild=guild, state=self._state))
raw_channel = data.get('channel', {})
channel_id = utils._get_as_snowflake(raw_channel, 'id')
if channel_id is not None and guild is not None:
self.channel = guild and guild._resolve_channel(channel_id)
raw_ch_type = raw_channel.get('type')
if self.channel is None and raw_ch_type is not None:
factory, ch_type = _threaded_channel_factory(raw_ch_type) # type is never None
if factory is None:
logging.info('Unknown channel type {type} for channel ID {id}.'.format_map(raw_channel))
else:
if ch_type in (ChannelType.group, ChannelType.private):
self.channel = factory(me=self._client.user, data=raw_channel, state=self._state) # type: ignore
elif guild is not None:
self.channel = factory(guild=guild, state=self._state, data=raw_channel) # type: ignore
self.message: Optional[Message]
try:
# The channel and message payloads are mismatched yet handled properly at runtime
@ -193,8 +271,11 @@ class Interaction(Generic[ClientT]):
self._permissions: int = 0
self._app_permissions: int = int(data.get('app_permissions', 0))
if self.guild_id:
guild = self._state._get_or_create_unavailable_guild(self.guild_id)
if guild is not None:
# Upgrade Message.guild in case it's missing with partial guild data
if self.message is not None and self.message.guild is None:
self.message.guild = guild
try:
member = data['member'] # type: ignore # The key is optional and handled
except KeyError:
@ -220,23 +301,15 @@ class Interaction(Generic[ClientT]):
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
return self._state and self._state._get_guild(self.guild_id)
@utils.cached_slot_property('_cs_channel')
def channel(self) -> Optional[InteractionChannel]:
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
# The user.guild attribute is set in __init__ to the fallback guild if available
# Therefore, we can use that instead of recreating it every time this property is
# accessed
return (self._state and self._state._get_guild(self.guild_id)) or getattr(self.user, 'guild', None)
Note that due to a Discord limitation, DM channels are not resolved since there is
no data to complete them. These are :class:`PartialMessageable` instead.
"""
guild = self.guild
channel = guild and guild._resolve_channel(self.channel_id)
if channel is None:
if self.channel_id is not None:
type = ChannelType.text if self.guild_id is not None else ChannelType.private
return PartialMessageable(state=self._state, guild_id=self.guild_id, id=self.channel_id, type=type)
return None
return channel
@property
def channel_id(self) -> Optional[int]:
"""Optional[:class:`int`]: The ID of the channel the interaction was sent from."""
return self.channel.id if self.channel is not None else None
@property
def permissions(self) -> Permissions:
@ -336,6 +409,22 @@ class Interaction(Generic[ClientT]):
""":class:`bool`: Returns ``True`` if the interaction is expired."""
return utils.utcnow() >= self.expires_at
def is_guild_integration(self) -> bool:
""":class:`bool`: Returns ``True`` if the interaction is a guild integration.
.. versionadded:: 2.4
"""
if self.guild_id:
return self.guild_id == self._integration_owners.get(0)
return False
def is_user_integration(self) -> bool:
""":class:`bool`: Returns ``True`` if the interaction is a user integration.
.. versionadded:: 2.4
"""
return self.user.id == self._integration_owners.get(1)
async def original_response(self) -> InteractionMessage:
"""|coro|
@ -395,6 +484,7 @@ class Interaction(Generic[ClientT]):
attachments: Sequence[Union[Attachment, File]] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
poll: Poll = MISSING,
) -> InteractionMessage:
"""|coro|
@ -429,6 +519,14 @@ class Interaction(Generic[ClientT]):
view: Optional[:class:`~discord.ui.View`]
The updated view to update this message with. If ``None`` is passed then
the view is removed.
poll: :class:`Poll`
The poll to create when editing the message.
.. versionadded:: 2.5
.. note::
This is only accepted when the response type is :attr:`InteractionResponseType.deferred_channel_message`.
Raises
-------
@ -458,6 +556,7 @@ class Interaction(Generic[ClientT]):
view=view,
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
poll=poll,
) as params:
adapter = async_context.get()
http = self._state.http
@ -550,6 +649,109 @@ class Interaction(Generic[ClientT]):
return await translator.translate(string, locale=locale, context=context)
class InteractionCallbackActivityInstance:
"""Represents an activity instance launched as an interaction response.
.. versionadded:: 2.5
Attributes
----------
id: :class:`str`
The activity instance ID.
"""
__slots__ = ('id',)
def __init__(self, data: InteractionCallbackActivityPayload) -> None:
self.id: str = data['id']
class InteractionCallbackResponse(Generic[ClientT]):
"""Represents an interaction response callback.
.. versionadded:: 2.5
Attributes
----------
id: :class:`int`
The interaction ID.
type: :class:`InteractionResponseType`
The interaction callback response type.
resource: Optional[Union[:class:`InteractionMessage`, :class:`InteractionCallbackActivityInstance`]]
The resource that the interaction response created. If a message was sent, this will be
a :class:`InteractionMessage`. If an activity was launched this will be a
:class:`InteractionCallbackActivityInstance`. In any other case, this will be ``None``.
message_id: Optional[:class:`int`]
The message ID of the resource. Only available if the resource is a :class:`InteractionMessage`.
activity_id: Optional[:class:`str`]
The activity ID of the resource. Only available if the resource is a :class:`InteractionCallbackActivityInstance`.
"""
__slots__ = (
'_state',
'_parent',
'type',
'id',
'_thinking',
'_ephemeral',
'message_id',
'activity_id',
'resource',
)
def __init__(
self,
*,
data: InteractionCallbackPayload,
parent: Interaction[ClientT],
state: ConnectionState,
type: InteractionResponseType,
) -> None:
self._state: ConnectionState = state
self._parent: Interaction[ClientT] = parent
self.type: InteractionResponseType = type
self._update(data)
def __repr__(self) -> str:
return f'<InteractionCallbackResponse id={self.id} type={self.type!r}>'
def _update(self, data: InteractionCallbackPayload) -> None:
interaction = data['interaction']
self.id: int = int(interaction['id'])
self._thinking: bool = interaction.get('response_message_loading', False)
self._ephemeral: bool = interaction.get('response_message_ephemeral', False)
self.message_id: Optional[int] = utils._get_as_snowflake(interaction, 'response_message_id')
self.activity_id: Optional[str] = interaction.get('activity_instance_id')
self.resource: Optional[InteractionCallbackResource] = None
resource = data.get('resource')
if resource is not None:
self.type = try_enum(InteractionResponseType, resource['type'])
message = resource.get('message')
activity_instance = resource.get('activity_instance')
if message is not None:
self.resource = InteractionMessage(
state=_InteractionMessageState(self._parent, self._state), # pyright: ignore[reportArgumentType]
channel=self._parent.channel, # type: ignore # channel should be the correct type here
data=message,
)
elif activity_instance is not None:
self.resource = InteractionCallbackActivityInstance(activity_instance)
def is_thinking(self) -> bool:
""":class:`bool`: Whether the response was a thinking defer."""
return self._thinking
def is_ephemeral(self) -> bool:
""":class:`bool`: Whether the response was ephemeral."""
return self._ephemeral
class InteractionResponse(Generic[ClientT]):
"""Represents a Discord interaction response.
@ -579,7 +781,12 @@ class InteractionResponse(Generic[ClientT]):
""":class:`InteractionResponseType`: The type of response that was sent, ``None`` if response is not done."""
return self._response_type
async def defer(self, *, ephemeral: bool = False, thinking: bool = False) -> None:
async def defer(
self,
*,
ephemeral: bool = False,
thinking: bool = False,
) -> Optional[InteractionCallbackResponse[ClientT]]:
"""|coro|
Defers the interaction response.
@ -593,6 +800,9 @@ class InteractionResponse(Generic[ClientT]):
- :attr:`InteractionType.component`
- :attr:`InteractionType.modal_submit`
.. versionchanged:: 2.5
This now returns a :class:`InteractionCallbackResponse` instance.
Parameters
-----------
ephemeral: :class:`bool`
@ -611,6 +821,11 @@ class InteractionResponse(Generic[ClientT]):
Deferring the interaction failed.
InteractionResponded
This interaction has already been responded to before.
Returns
-------
Optional[:class:`InteractionCallbackResponse`]
The interaction callback resource, or ``None``.
"""
if self._response_type:
raise InteractionResponded(self._parent)
@ -635,7 +850,7 @@ class InteractionResponse(Generic[ClientT]):
adapter = async_context.get()
params = interaction_response_params(type=defer_type, data=data)
http = parent._state.http
await adapter.create_interaction_response(
response = await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
@ -644,6 +859,12 @@ class InteractionResponse(Generic[ClientT]):
params=params,
)
self._response_type = InteractionResponseType(defer_type)
return InteractionCallbackResponse(
data=response,
parent=self._parent,
state=self._parent._state,
type=self._response_type,
)
async def pong(self) -> None:
"""|coro|
@ -692,11 +913,15 @@ class InteractionResponse(Generic[ClientT]):
suppress_embeds: bool = False,
silent: bool = False,
delete_after: Optional[float] = None,
) -> None:
poll: Poll = MISSING,
) -> InteractionCallbackResponse[ClientT]:
"""|coro|
Responds to this interaction by sending a message.
.. versionchanged:: 2.5
This now returns a :class:`InteractionCallbackResponse` instance.
Parameters
-----------
content: Optional[:class:`str`]
@ -735,6 +960,10 @@ class InteractionResponse(Generic[ClientT]):
then it is silently ignored.
.. versionadded:: 2.1
poll: :class:`~discord.Poll`
The poll to send with this message.
.. versionadded:: 2.4
Raises
-------
@ -746,6 +975,11 @@ class InteractionResponse(Generic[ClientT]):
The length of ``embeds`` was invalid.
InteractionResponded
This interaction has already been responded to before.
Returns
-------
:class:`InteractionCallbackResponse`
The interaction callback data.
"""
if self._response_type:
raise InteractionResponded(self._parent)
@ -772,10 +1006,11 @@ class InteractionResponse(Generic[ClientT]):
allowed_mentions=allowed_mentions,
flags=flags,
view=view,
poll=poll,
)
http = parent._state.http
await adapter.create_interaction_response(
response = await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
@ -806,6 +1041,13 @@ class InteractionResponse(Generic[ClientT]):
asyncio.create_task(inner_call())
return InteractionCallbackResponse(
data=response,
parent=self._parent,
state=self._parent._state,
type=self._response_type,
)
async def edit_message(
self,
*,
@ -816,12 +1058,16 @@ class InteractionResponse(Generic[ClientT]):
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = MISSING,
delete_after: Optional[float] = None,
) -> None:
suppress_embeds: bool = MISSING,
) -> Optional[InteractionCallbackResponse[ClientT]]:
"""|coro|
Responds to this interaction by editing the original message of
a component or modal interaction.
.. versionchanged:: 2.5
This now returns a :class:`InteractionCallbackResponse` instance.
Parameters
-----------
content: Optional[:class:`str`]
@ -851,6 +1097,13 @@ class InteractionResponse(Generic[ClientT]):
then it is silently ignored.
.. versionadded:: 2.2
suppress_embeds: :class:`bool`
Whether to suppress embeds for the message. This removes
all the embeds if set to ``True``. If set to ``False``
this brings the embeds back if they were suppressed.
Using this parameter requires :attr:`~.Permissions.manage_messages`.
.. versionadded:: 2.4
Raises
-------
@ -860,6 +1113,11 @@ class InteractionResponse(Generic[ClientT]):
You specified both ``embed`` and ``embeds``.
InteractionResponded
This interaction has already been responded to before.
Returns
-------
Optional[:class:`InteractionCallbackResponse`]
The interaction callback data, or ``None`` if editing the message was not possible.
"""
if self._response_type:
raise InteractionResponded(self._parent)
@ -871,7 +1129,7 @@ class InteractionResponse(Generic[ClientT]):
message_id = msg.id
# If this was invoked via an application command then we can use its original interaction ID
# Since this is used as a cache key for view updates
original_interaction_id = msg.interaction.id if msg.interaction is not None else None
original_interaction_id = msg.interaction_metadata.id if msg.interaction_metadata is not None else None
else:
message_id = None
original_interaction_id = None
@ -882,6 +1140,12 @@ class InteractionResponse(Generic[ClientT]):
if view is not MISSING and message_id is not None:
state.prevent_view_updates_for(message_id)
if suppress_embeds is not MISSING:
flags = MessageFlags._from_value(0)
flags.suppress_embeds = suppress_embeds
else:
flags = MISSING
adapter = async_context.get()
params = interaction_message_response_params(
type=InteractionResponseType.message_update.value,
@ -892,10 +1156,11 @@ class InteractionResponse(Generic[ClientT]):
attachments=attachments,
previous_allowed_mentions=parent._state.allowed_mentions,
allowed_mentions=allowed_mentions,
flags=flags,
)
http = parent._state.http
await adapter.create_interaction_response(
response = await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
@ -920,11 +1185,21 @@ class InteractionResponse(Generic[ClientT]):
asyncio.create_task(inner_call())
async def send_modal(self, modal: Modal, /) -> None:
return InteractionCallbackResponse(
data=response,
parent=self._parent,
state=self._parent._state,
type=self._response_type,
)
async def send_modal(self, modal: Modal, /) -> InteractionCallbackResponse[ClientT]:
"""|coro|
Responds to this interaction by sending a modal.
.. versionchanged:: 2.5
This now returns a :class:`InteractionCallbackResponse` instance.
Parameters
-----------
modal: :class:`~discord.ui.Modal`
@ -936,6 +1211,11 @@ class InteractionResponse(Generic[ClientT]):
Sending the modal failed.
InteractionResponded
This interaction has already been responded to before.
Returns
-------
:class:`InteractionCallbackResponse`
The interaction callback data.
"""
if self._response_type:
raise InteractionResponded(self._parent)
@ -946,7 +1226,7 @@ class InteractionResponse(Generic[ClientT]):
http = parent._state.http
params = interaction_response_params(InteractionResponseType.modal.value, modal.to_dict())
await adapter.create_interaction_response(
response = await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
@ -958,6 +1238,13 @@ class InteractionResponse(Generic[ClientT]):
self._parent._state.store_view(modal)
self._response_type = InteractionResponseType.modal
return InteractionCallbackResponse(
data=response,
parent=self._parent,
state=self._parent._state,
type=self._response_type,
)
async def autocomplete(self, choices: Sequence[Choice[ChoiceT]]) -> None:
"""|coro|
@ -1009,6 +1296,52 @@ class InteractionResponse(Generic[ClientT]):
self._response_type = InteractionResponseType.autocomplete_result
async def launch_activity(self) -> InteractionCallbackResponse[ClientT]:
"""|coro|
Responds to this interaction by launching the activity associated with the app.
Only available for apps with activities enabled.
.. versionadded:: 2.6
Raises
-------
HTTPException
Launching the activity failed.
InteractionResponded
This interaction has already been responded to before.
Returns
-------
:class:`InteractionCallbackResponse`
The interaction callback data.
"""
if self._response_type:
raise InteractionResponded(self._parent)
parent = self._parent
adapter = async_context.get()
http = parent._state.http
params = interaction_response_params(InteractionResponseType.launch_activity.value)
response = await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
proxy=http.proxy,
proxy_auth=http.proxy_auth,
params=params,
)
self._response_type = InteractionResponseType.launch_activity
return InteractionCallbackResponse(
data=response,
parent=self._parent,
state=self._parent._state,
type=self._response_type,
)
class _InteractionMessageState:
__slots__ = ('_parent', '_interaction')
@ -1020,8 +1353,8 @@ class _InteractionMessageState:
def _get_guild(self, guild_id):
return self._parent._get_guild(guild_id)
def store_user(self, data):
return self._parent.store_user(data)
def store_user(self, data, *, cache: bool = True):
return self._parent.store_user(data, cache=cache)
def create_user(self, data):
return self._parent.create_user(data)
@ -1059,6 +1392,7 @@ class InteractionMessage(Message):
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
delete_after: Optional[float] = None,
poll: Poll = MISSING,
) -> InteractionMessage:
"""|coro|
@ -1093,6 +1427,15 @@ class InteractionMessage(Message):
then it is silently ignored.
.. versionadded:: 2.2
poll: :class:`~discord.Poll`
The poll to create when editing the message.
.. versionadded:: 2.5
.. note::
This is only accepted if the interaction response's :attr:`InteractionResponse.type`
attribute is :attr:`InteractionResponseType.deferred_channel_message`.
Raises
-------
@ -1117,6 +1460,7 @@ class InteractionMessage(Message):
attachments=attachments,
view=view,
allowed_mentions=allowed_mentions,
poll=poll,
)
if delete_after is not None:
await self.delete(delay=delete_after)

31
discord/invite.py

@ -29,9 +29,10 @@ from .asset import Asset
from .utils import parse_time, snowflake_time, _get_as_snowflake
from .object import Object
from .mixins import Hashable
from .enums import ChannelType, NSFWLevel, VerificationLevel, InviteTarget, try_enum
from .enums import ChannelType, NSFWLevel, VerificationLevel, InviteTarget, InviteType, try_enum
from .appinfo import PartialAppInfo
from .scheduled_event import ScheduledEvent
from .flags import InviteFlags
__all__ = (
'PartialInviteChannel',
@ -47,6 +48,7 @@ if TYPE_CHECKING:
InviteGuild as InviteGuildPayload,
GatewayInvite as GatewayInvitePayload,
)
from .types.guild import GuildFeature
from .types.channel import (
PartialChannel as InviteChannelPayload,
)
@ -189,7 +191,7 @@ class PartialInviteGuild:
self._state: ConnectionState = state
self.id: int = id
self.name: str = data['name']
self.features: List[str] = data.get('features', [])
self.features: List[GuildFeature] = data.get('features', [])
self._icon: Optional[str] = data.get('icon')
self._banner: Optional[str] = data.get('banner')
self._splash: Optional[str] = data.get('splash')
@ -295,6 +297,10 @@ class Invite(Hashable):
Attributes
-----------
type: :class:`InviteType`
The type of the invite.
.. versionadded: 2.4
max_age: Optional[:class:`int`]
How long before the invite expires in seconds.
A value of ``0`` indicates that it doesn't expire.
@ -373,6 +379,8 @@ class Invite(Hashable):
'expires_at',
'scheduled_event',
'scheduled_event_id',
'type',
'_flags',
)
BASE = 'https://discord.gg'
@ -386,6 +394,7 @@ class Invite(Hashable):
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
):
self._state: ConnectionState = state
self.type: InviteType = try_enum(InviteType, data.get('type', 0))
self.max_age: Optional[int] = data.get('max_age')
self.code: str = data['code']
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild)
@ -425,12 +434,13 @@ class Invite(Hashable):
else None
)
self.scheduled_event_id: Optional[int] = self.scheduled_event.id if self.scheduled_event else None
self._flags: int = data.get('flags', 0)
@classmethod
def from_incomplete(cls, *, state: ConnectionState, data: InvitePayload) -> Self:
guild: Optional[Union[Guild, PartialInviteGuild]]
try:
guild_data = data['guild']
guild_data = data['guild'] # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
# If we're here, then this is a group DM
guild = None
@ -495,7 +505,7 @@ class Invite(Hashable):
def __repr__(self) -> str:
return (
f'<Invite code={self.code!r} guild={self.guild!r} '
f'<Invite type={self.type} code={self.code!r} guild={self.guild!r} '
f'online={self.approximate_presence_count} '
f'members={self.approximate_member_count}>'
)
@ -516,6 +526,14 @@ class Invite(Hashable):
url += '?event=' + str(self.scheduled_event_id)
return url
@property
def flags(self) -> InviteFlags:
""":class:`InviteFlags`: Returns the flags for this invite.
.. versionadded:: 2.6
"""
return InviteFlags._from_value(self._flags)
def set_scheduled_event(self, scheduled_event: Snowflake, /) -> Self:
"""Sets the scheduled event for this invite.
@ -539,7 +557,7 @@ class Invite(Hashable):
return self
async def delete(self, *, reason: Optional[str] = None) -> None:
async def delete(self, *, reason: Optional[str] = None) -> Self:
"""|coro|
Revokes the instant invite.
@ -561,4 +579,5 @@ class Invite(Hashable):
Revoking the invite failed.
"""
await self._state.http.delete_invite(self.code, reason=reason)
data = await self._state.http.delete_invite(self.code, reason=reason)
return self.from_incomplete(state=self._state, data=data)

232
discord/member.py

@ -35,14 +35,14 @@ import discord.abc
from . import utils
from .asset import Asset
from .utils import MISSING
from .user import BaseUser, User, _UserTag
from .activity import create_activity, ActivityTypes
from .user import BaseUser, ClientUser, User, _UserTag
from .permissions import Permissions
from .enums import Status, try_enum
from .enums import Status
from .errors import ClientException
from .colour import Colour
from .object import Object
from .flags import MemberFlags
from .presences import ClientStatus
__all__ = (
'VoiceState',
@ -57,17 +57,15 @@ if TYPE_CHECKING:
from .channel import DMChannel, VoiceChannel, StageChannel
from .flags import PublicUserFlags
from .guild import Guild
from .types.activity import (
ClientStatus as ClientStatusPayload,
PartialPresenceUpdate,
)
from .activity import ActivityTypes
from .presences import RawPresenceUpdateEvent
from .types.member import (
MemberWithUser as MemberWithUserPayload,
Member as MemberPayload,
UserWithMember as UserWithMemberPayload,
)
from .types.gateway import GuildMemberUpdateEvent
from .types.user import User as UserPayload
from .types.user import User as UserPayload, AvatarDecorationData
from .abc import Snowflake
from .state import ConnectionState
from .message import Message
@ -168,46 +166,6 @@ class VoiceState:
return f'<{self.__class__.__name__} {inner}>'
class _ClientStatus:
__slots__ = ('_status', 'desktop', 'mobile', 'web')
def __init__(self):
self._status: str = 'offline'
self.desktop: Optional[str] = None
self.mobile: Optional[str] = None
self.web: Optional[str] = None
def __repr__(self) -> str:
attrs = [
('_status', self._status),
('desktop', self.desktop),
('mobile', self.mobile),
('web', self.web),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>'
def _update(self, status: str, data: ClientStatusPayload, /) -> None:
self._status = status
self.desktop = data.get('desktop')
self.mobile = data.get('mobile')
self.web = data.get('web')
@classmethod
def _copy(cls, client_status: Self, /) -> Self:
self = cls.__new__(cls) # bypass __init__
self._status = client_status._status
self.desktop = client_status.desktop
self.mobile = client_status.mobile
self.web = client_status.web
return self
def flatten_user(cls: T) -> T:
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods
@ -274,13 +232,14 @@ class Member(discord.abc.Messageable, _UserTag):
.. describe:: str(x)
Returns the member's name with the discriminator.
Returns the member's handle (e.g. ``name`` or ``name#discriminator``).
Attributes
----------
joined_at: Optional[:class:`datetime.datetime`]
An aware datetime object that specifies the date and time in UTC that the member joined the guild.
If the member left and rejoined the guild, this will be the latest date. In certain cases, this can be ``None``.
If the member left and rejoined the guild, this will be the latest date.
This can be ``None``, such as when the member is a guest.
activities: Tuple[Union[:class:`BaseActivity`, :class:`Spotify`]]
The activities that the user is currently doing.
@ -293,7 +252,7 @@ class Member(discord.abc.Messageable, _UserTag):
guild: :class:`Guild`
The guild that the member belongs to.
nick: Optional[:class:`str`]
The guild specific nickname of the user.
The guild specific nickname of the user. Takes precedence over the global name.
pending: :class:`bool`
Whether the member is pending member verification.
@ -303,9 +262,13 @@ class Member(discord.abc.Messageable, _UserTag):
"Nitro boost" on the guild, if available. This could be ``None``.
timed_out_until: Optional[:class:`datetime.datetime`]
An aware datetime object that specifies the date and time in UTC that the member's time out will expire.
This will be set to ``None`` if the user is not timed out.
This will be set to ``None`` or a time in the past if the user is not timed out.
.. versionadded:: 2.0
client_status: :class:`ClientStatus`
Model which holds information about the status of the member on various clients/platforms via presence updates.
.. versionadded:: 2.5
"""
__slots__ = (
@ -318,17 +281,20 @@ class Member(discord.abc.Messageable, _UserTag):
'nick',
'timed_out_until',
'_permissions',
'_client_status',
'client_status',
'_user',
'_state',
'_avatar',
'_banner',
'_flags',
'_avatar_decoration_data',
)
if TYPE_CHECKING:
name: str
id: int
discriminator: str
global_name: Optional[str]
bot: bool
system: bool
created_at: datetime.datetime
@ -341,6 +307,8 @@ class Member(discord.abc.Messageable, _UserTag):
banner: Optional[Asset]
accent_color: Optional[Colour]
accent_colour: Optional[Colour]
avatar_decoration: Optional[Asset]
avatar_decoration_sku_id: Optional[int]
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
self._state: ConnectionState = state
@ -349,15 +317,17 @@ class Member(discord.abc.Messageable, _UserTag):
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at'))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since'))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles']))
self._client_status: _ClientStatus = _ClientStatus()
self.client_status: ClientStatus = ClientStatus()
self.activities: Tuple[ActivityTypes, ...] = ()
self.nick: Optional[str] = data.get('nick', None)
self.pending: bool = data.get('pending', False)
self._avatar: Optional[str] = data.get('avatar')
self._banner: Optional[str] = data.get('banner')
self._permissions: Optional[int]
self._flags: int = data['flags']
self._avatar_decoration_data: Optional[AvatarDecorationData] = data.get('avatar_decoration_data')
try:
self._permissions = int(data['permissions'])
self._permissions = int(data['permissions']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self._permissions = None
@ -368,7 +338,7 @@ class Member(discord.abc.Messageable, _UserTag):
def __repr__(self) -> str:
return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
f'<Member id={self._user.id} name={self._user.name!r} global_name={self._user.global_name!r}'
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
)
@ -387,6 +357,15 @@ class Member(discord.abc.Messageable, _UserTag):
data['user'] = author._to_minimal_user_json() # type: ignore
return cls(data=data, guild=message.guild, state=message._state) # type: ignore
@classmethod
def _from_client_user(cls, *, user: ClientUser, guild: Guild, state: ConnectionState) -> Self:
data = {
'roles': [],
'user': user._to_minimal_user_json(),
'flags': 0,
}
return cls(data=data, guild=guild, state=state) # type: ignore
def _update_from_message(self, data: MemberPayload) -> None:
self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since'))
@ -414,7 +393,7 @@ class Member(discord.abc.Messageable, _UserTag):
self._roles = utils.SnowflakeList(member._roles, is_sorted=True)
self.joined_at = member.joined_at
self.premium_since = member.premium_since
self._client_status = _ClientStatus._copy(member._client_status)
self.client_status = member.client_status
self.guild = member.guild
self.nick = member.nick
self.pending = member.pending
@ -424,6 +403,8 @@ class Member(discord.abc.Messageable, _UserTag):
self._permissions = member._permissions
self._state = member._state
self._avatar = member._avatar
self._banner = member._banner
self._avatar_decoration_data = member._avatar_decoration_data
# Reference will not be copied unless necessary by PRESENCE_UPDATE
# See below
@ -438,12 +419,12 @@ class Member(discord.abc.Messageable, _UserTag):
# the nickname change is optional,
# if it isn't in the payload then it didn't change
try:
self.nick = data['nick']
self.nick = data['nick'] # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
pass
try:
self.pending = data['pending']
self.pending = data['pending'] # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
pass
@ -451,31 +432,55 @@ class Member(discord.abc.Messageable, _UserTag):
self.timed_out_until = utils.parse_time(data.get('communication_disabled_until'))
self._roles = utils.SnowflakeList(map(int, data['roles']))
self._avatar = data.get('avatar')
self._banner = data.get('banner')
self._flags = data.get('flags', 0)
self._avatar_decoration_data = data.get('avatar_decoration_data')
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = tuple(create_activity(d, self._state) for d in data['activities'])
self._client_status._update(data['status'], data['client_status'])
def _presence_update(self, raw: RawPresenceUpdateEvent, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = raw.activities
self.client_status = raw.client_status
if len(user) > 1:
return self._update_inner_user(user)
return None
def _update_inner_user(self, user: UserPayload) -> Optional[Tuple[User, User]]:
u = self._user
original = (u.name, u._avatar, u.discriminator, u._public_flags)
original = (
u.name,
u.discriminator,
u._avatar,
u.global_name,
u._public_flags,
u._avatar_decoration_data['sku_id'] if u._avatar_decoration_data is not None else None,
)
decoration_payload = user.get('avatar_decoration_data')
# These keys seem to always be available
modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0))
modified = (
user['username'],
user['discriminator'],
user['avatar'],
user.get('global_name'),
user.get('public_flags', 0),
decoration_payload['sku_id'] if decoration_payload is not None else None,
)
if original != modified:
to_return = User._copy(self._user)
u.name, u._avatar, u.discriminator, u._public_flags = modified
u.name, u.discriminator, u._avatar, u.global_name, u._public_flags, u._avatar_decoration_data = (
user['username'],
user['discriminator'],
user['avatar'],
user.get('global_name'),
user.get('public_flags', 0),
decoration_payload,
)
# Signal to dispatch on_user_update
return to_return, u
@property
def status(self) -> Status:
""":class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead."""
return try_enum(Status, self._client_status._status)
return self.client_status.status
@property
def raw_status(self) -> str:
@ -483,31 +488,36 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionadded:: 1.5
"""
return self._client_status._status
return self.client_status._status
@status.setter
def status(self, value: Status) -> None:
# internal use only
self._client_status._status = str(value)
self.client_status._status = str(value)
@property
def mobile_status(self) -> Status:
""":class:`Status`: The member's status on a mobile device, if applicable."""
return try_enum(Status, self._client_status.mobile or 'offline')
return self.client_status.mobile_status
@property
def desktop_status(self) -> Status:
""":class:`Status`: The member's status on the desktop client, if applicable."""
return try_enum(Status, self._client_status.desktop or 'offline')
return self.client_status.desktop_status
@property
def web_status(self) -> Status:
""":class:`Status`: The member's status on the web client, if applicable."""
return try_enum(Status, self._client_status.web or 'offline')
return self.client_status.web_status
def is_on_mobile(self) -> bool:
""":class:`bool`: A helper function that determines if a member is active on a mobile device."""
return self._client_status.mobile is not None
"""A helper function that determines if a member is active on a mobile device.
Returns
-------
:class:`bool`
"""
return self.client_status.is_on_mobile()
@property
def colour(self) -> Colour:
@ -552,7 +562,9 @@ class Member(discord.abc.Messageable, _UserTag):
role = g.get_role(role_id)
if role:
result.append(role)
result.append(g.default_role)
default_role = g.default_role
if default_role:
result.append(default_role)
result.sort()
return result
@ -581,11 +593,11 @@ class Member(discord.abc.Messageable, _UserTag):
def display_name(self) -> str:
""":class:`str`: Returns the user's display name.
For regular users this is just their username, but
if they have a guild specific nickname then that
For regular users this is just their global name or their username,
but if they have a guild specific nickname then that
is returned instead.
"""
return self.nick or self.name
return self.nick or self.global_name or self.name
@property
def display_avatar(self) -> Asset:
@ -610,6 +622,28 @@ class Member(discord.abc.Messageable, _UserTag):
return None
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
@property
def display_banner(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the member's displayed banner, if any.
This is the member's guild banner if available, otherwise it's their
global banner. If the member has no banner set then ``None`` is returned.
.. versionadded:: 2.5
"""
return self.guild_banner or self._user.banner
@property
def guild_banner(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the guild banner
the member has. If unavailable, ``None`` is returned.
.. versionadded:: 2.5
"""
if self._banner is None:
return None
return Asset._from_guild_banner(self._state, self.guild.id, self.id, self._banner)
@property
def activity(self) -> Optional[ActivityTypes]:
"""Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary
@ -833,7 +867,7 @@ class Member(discord.abc.Messageable, _UserTag):
Raises
-------
Forbidden
You do not have the proper permissions to the action requested.
You do not have the proper permissions to do the action requested.
HTTPException
The operation failed.
TypeError
@ -925,7 +959,7 @@ class Member(discord.abc.Messageable, _UserTag):
ClientException
You are not connected to a voice channel.
Forbidden
You do not have the proper permissions to the action requested.
You do not have the proper permissions to do the action requested.
HTTPException
The operation failed.
"""
@ -1011,7 +1045,7 @@ class Member(discord.abc.Messageable, _UserTag):
You must have :attr:`~Permissions.manage_roles` to
use this, and the added :class:`Role`\s must appear lower in the list
of roles than the highest role of the member.
of roles than the highest role of the client.
Parameters
-----------
@ -1050,7 +1084,7 @@ class Member(discord.abc.Messageable, _UserTag):
You must have :attr:`~Permissions.manage_roles` to
use this, and the removed :class:`Role`\s must appear lower in the list
of roles than the highest role of the member.
of roles than the highest role of the client.
Parameters
-----------
@ -1088,6 +1122,40 @@ class Member(discord.abc.Messageable, _UserTag):
for role in roles:
await req(guild_id, user_id, role.id, reason=reason)
async def fetch_voice(self) -> VoiceState:
"""|coro|
Retrieves the current voice state from this member.
.. versionadded:: 2.5
Raises
-------
NotFound
The member is not in a voice channel.
Forbidden
You do not have permissions to get a voice state.
HTTPException
Retrieving the voice state failed.
Returns
-------
:class:`VoiceState`
The current voice state of the member.
"""
guild_id = self.guild.id
if self._state.self_id == self.id:
data = await self._state.http.get_my_voice_state(guild_id)
else:
data = await self._state.http.get_voice_state(guild_id, self.id)
channel_id = data.get('channel_id')
channel: Optional[VocalGuildChannel] = None
if channel_id is not None:
channel = self.guild.get_channel(int(channel_id)) # type: ignore # must be voice channel here
return VoiceState(data=data, channel=channel)
def get_role(self, role_id: int, /) -> Optional[Role]:
"""Returns a role with the given ID from roles which the member has.

930
discord/message.py

File diff suppressed because it is too large

2
discord/object.py

@ -102,7 +102,7 @@ class Object(Hashable):
return f'<Object id={self.id!r} type={self.type!r}>'
def __eq__(self, other: object) -> bool:
if isinstance(other, self.type):
if isinstance(other, (self.type, self.__class__)):
return self.id == other.id
return NotImplemented

2
discord/oggparse.py

@ -99,7 +99,7 @@ class OggStream:
elif not head:
return None
else:
raise OggError('invalid header magic')
raise OggError(f'invalid header magic {head}')
def _iter_pages(self) -> Generator[OggPage, None, None]:
page = self._next_page()

68
discord/opus.py

@ -39,10 +39,17 @@ from .errors import DiscordException
if TYPE_CHECKING:
T = TypeVar('T')
APPLICATION_CTL = Literal['audio', 'voip', 'lowdelay']
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
SIGNAL_CTL = Literal['auto', 'voice', 'music']
class ApplicationCtl(TypedDict):
audio: int
voip: int
lowdelay: int
class BandCtl(TypedDict):
narrow: int
medium: int
@ -65,6 +72,8 @@ __all__ = (
_log = logging.getLogger(__name__)
OPUS_SILENCE = b'\xF8\xFF\xFE'
c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
c_float_ptr = ctypes.POINTER(ctypes.c_float)
@ -90,9 +99,10 @@ OK = 0
BAD_ARG = -1
# Encoder CTLs
APPLICATION_AUDIO = 2049
APPLICATION_VOIP = 2048
APPLICATION_LOWDELAY = 2051
APPLICATION_AUDIO = 'audio'
APPLICATION_VOIP = 'voip'
APPLICATION_LOWDELAY = 'lowdelay'
# These remain as strings for backwards compat
CTL_SET_BITRATE = 4002
CTL_SET_BANDWIDTH = 4008
@ -105,6 +115,12 @@ CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039
# fmt: on
application_ctl: ApplicationCtl = {
'audio': 2049,
'voip': 2048,
'lowdelay': 2051,
}
band_ctl: BandCtl = {
'narrow': 1101,
'medium': 1102,
@ -319,16 +335,38 @@ class _OpusStruct:
class Encoder(_OpusStruct):
def __init__(self, application: int = APPLICATION_AUDIO):
_OpusStruct.get_opus_version()
self.application: int = application
def __init__(
self,
*,
application: APPLICATION_CTL = 'audio',
bitrate: int = 128,
fec: bool = True,
expected_packet_loss: float = 0.15,
bandwidth: BAND_CTL = 'full',
signal_type: SIGNAL_CTL = 'auto',
):
if application not in application_ctl:
raise ValueError(f'{application} is not a valid application setting. Try one of: {"".join(application_ctl)}')
if not 16 <= bitrate <= 512:
raise ValueError(f'bitrate must be between 16 and 512, not {bitrate}')
if not 0 < expected_packet_loss <= 1.0:
raise ValueError(
f'expected_packet_loss must be a positive number less than or equal to 1, not {expected_packet_loss}'
)
_OpusStruct.get_opus_version() # lazy loads the opus library
self.application: int = application_ctl[application]
self._state: EncoderStruct = self._create_state()
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth('full')
self.set_signal_type('auto')
self.set_bitrate(bitrate)
self.set_fec(fec)
if fec:
self.set_expected_packet_loss_percent(expected_packet_loss)
self.set_bandwidth(bandwidth)
self.set_signal_type(signal_type)
def __del__(self) -> None:
if hasattr(self, '_state'):
@ -355,7 +393,7 @@ class Encoder(_OpusStruct):
def set_signal_type(self, req: SIGNAL_CTL) -> None:
if req not in signal_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
raise KeyError(f'{req!r} is not a valid signal type setting. Try one of: {",".join(signal_ctl)}')
k = signal_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
@ -454,7 +492,9 @@ class Decoder(_OpusStruct):
channel_count = self.CHANNELS
else:
frames = self.packet_get_nb_frames(data)
channel_count = self.packet_get_nb_channels(data)
# Discord silent frames erroneously present themselves as 1 channel instead of 2
# Therefore we need to hardcode the number instead of using packet_get_nb_channels
channel_count = self.CHANNELS
samples_per_frame = self.packet_get_samples_per_frame(data)
frame_size = frames * samples_per_frame

2
discord/partial_emoji.py

@ -94,7 +94,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
__slots__ = ('animated', 'name', 'id', '_state')
_CUSTOM_EMOJI_RE = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?')
_CUSTOM_EMOJI_RE = re.compile(r'<?(?:(?P<animated>a)?:)?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?')
if TYPE_CHECKING:
id: Optional[int]

195
discord/permissions.py

@ -119,6 +119,12 @@ class Permissions(BaseFlags):
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. describe:: bool(b)
Returns whether the permissions object has any permissions set to ``True``.
.. versionadded:: 2.0
Attributes
-----------
value: :class:`int`
@ -135,9 +141,12 @@ class Permissions(BaseFlags):
self.value = permissions
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid permission name.')
setattr(self, key, value)
try:
flag = self.VALID_FLAGS[key]
except KeyError:
raise TypeError(f'{key!r} is not a valid permission name.') from None
else:
self._set_flag(flag, value)
def is_subset(self, other: Permissions) -> bool:
"""Returns ``True`` if self has the same or fewer permissions as other."""
@ -177,7 +186,8 @@ class Permissions(BaseFlags):
"""A factory method that creates a :class:`Permissions` with all
permissions set to ``True``.
"""
return cls(0b11111111111111111111111111111111111111111)
# Some of these are 0 because we don't want to set unnecessary bits
return cls(0b0000_0000_0000_0110_0111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111)
@classmethod
def _timeout_mask(cls) -> int:
@ -198,13 +208,29 @@ class Permissions(BaseFlags):
base.send_messages_in_threads = False
return base
@classmethod
def _user_installed_permissions(cls, *, in_guild: bool) -> Self:
base = cls.none()
base.send_messages = True
base.attach_files = True
base.embed_links = True
base.external_emojis = True
base.send_voice_messages = True
if in_guild:
# Logically this is False but if not set to True,
# permissions just become 0.
base.read_messages = True
base.send_tts_messages = True
base.send_messages_in_threads = True
return base
@classmethod
def all_channel(cls) -> Self:
"""A :class:`Permissions` with all channel-specific permissions set to
``True`` and the guild-specific ones set to ``False``. The guild-specific
permissions are currently:
- :attr:`manage_emojis`
- :attr:`manage_expressions`
- :attr:`view_audit_log`
- :attr:`view_guild_insights`
- :attr:`manage_guild`
@ -213,6 +239,11 @@ class Permissions(BaseFlags):
- :attr:`kick_members`
- :attr:`ban_members`
- :attr:`administrator`
- :attr:`create_expressions`
- :attr:`moderate_members`
- :attr:`create_events`
- :attr:`manage_events`
- :attr:`view_creator_monetization_analytics`
.. versionchanged:: 1.7
Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_application_commands` permissions.
@ -221,8 +252,15 @@ class Permissions(BaseFlags):
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`use_external_stickers`, :attr:`send_messages_in_threads` and
:attr:`request_to_speak` permissions.
.. versionchanged:: 2.3
Added :attr:`use_soundboard`, :attr:`create_expressions` permissions.
.. versionchanged:: 2.4
Added :attr:`send_polls`, :attr:`send_voice_messages`, attr:`use_external_sounds`,
:attr:`use_embedded_activities`, and :attr:`use_external_apps` permissions.
"""
return cls(0b111110110110011111101111111111101010001)
return cls(0b0000_0000_0000_0110_0110_0100_1111_1101_1011_0011_1111_0111_1111_1111_0101_0001)
@classmethod
def general(cls) -> Self:
@ -234,8 +272,14 @@ class Permissions(BaseFlags):
permissions :attr:`administrator`, :attr:`create_instant_invite`, :attr:`kick_members`,
:attr:`ban_members`, :attr:`change_nickname` and :attr:`manage_nicknames` are
no longer part of the general permissions.
.. versionchanged:: 2.3
Added :attr:`create_expressions` permission.
.. versionchanged:: 2.4
Added :attr:`view_creator_monetization_analytics` permission.
"""
return cls(0b01110000000010000000010010110000)
return cls(0b0000_0000_0000_0000_0000_1010_0000_0000_0111_0000_0000_1000_0000_0100_1011_0000)
@classmethod
def membership(cls) -> Self:
@ -244,7 +288,7 @@ class Permissions(BaseFlags):
.. versionadded:: 1.7
"""
return cls(0b10000000000001100000000000000000000000111)
return cls(0b0000_0000_0000_0000_0000_0001_0000_0000_0000_1100_0000_0000_0000_0000_0000_0111)
@classmethod
def text(cls) -> Self:
@ -258,14 +302,20 @@ class Permissions(BaseFlags):
.. versionchanged:: 2.0
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`send_messages_in_threads` and :attr:`use_external_stickers` permissions.
.. versionchanged:: 2.3
Added :attr:`send_voice_messages` permission.
.. versionchanged:: 2.4
Added :attr:`send_polls` and :attr:`use_external_apps` permissions.
"""
return cls(0b111110010000000000001111111100001000000)
return cls(0b0000_0000_0000_0110_0100_0000_0111_1100_1000_0000_0000_0111_1111_1000_0100_0000)
@classmethod
def voice(cls) -> Self:
"""A factory method that creates a :class:`Permissions` with all
"Voice" permissions from the official Discord UI set to ``True``."""
return cls(0b1000000000000011111100000000001100000000)
return cls(0b0000_0000_0000_0000_0010_0100_1000_0000_0000_0011_1111_0000_0000_0011_0000_0000)
@classmethod
def stage(cls) -> Self:
@ -290,7 +340,7 @@ class Permissions(BaseFlags):
.. versionchanged:: 2.0
Added :attr:`manage_channels` permission and removed :attr:`request_to_speak` permission.
"""
return cls(0b1010000000000000000010000)
return cls(0b0000_0000_0000_0000_0000_0000_0000_0000_0000_0001_0100_0000_0000_0000_0001_0000)
@classmethod
def elevated(cls) -> Self:
@ -305,13 +355,32 @@ class Permissions(BaseFlags):
- :attr:`manage_messages`
- :attr:`manage_roles`
- :attr:`manage_webhooks`
- :attr:`manage_emojis_and_stickers`
- :attr:`manage_expressions`
- :attr:`manage_threads`
- :attr:`moderate_members`
.. versionadded:: 2.0
"""
return cls(0b10000010001110000000000000010000000111110)
return cls(0b0000_0000_0000_0000_0000_0001_0000_0100_0111_0000_0000_0000_0010_0000_0011_1110)
@classmethod
def apps(cls) -> Self:
"""A factory method that creates a :class:`Permissions` with all
"Apps" permissions from the official Discord UI set to ``True``.
.. versionadded:: 2.6
"""
return cls(0b0000_0000_0000_0100_0000_0000_1000_0000_1000_0000_0000_0000_0000_0000_0000_0000)
@classmethod
def events(cls) -> Self:
"""A factory method that creates a :class:`Permissions` with all
"Events" permissions from the official Discord UI set to ``True``.
.. versionadded:: 2.4
"""
return cls(0b0000_0000_0000_0000_0001_0000_0000_0010_0000_0000_0000_0000_0000_0000_0000_0000)
@classmethod
def advanced(cls) -> Self:
@ -335,8 +404,9 @@ class Permissions(BaseFlags):
A list of key/value pairs to bulk update permissions with.
"""
for key, value in kwargs.items():
if key in self.VALID_FLAGS:
setattr(self, key, value)
flag = self.VALID_FLAGS.get(key)
if flag is not None:
self._set_flag(flag, value)
def handle_overwrite(self, allow: int, deny: int) -> None:
# Basically this is what's happening here.
@ -544,13 +614,21 @@ class Permissions(BaseFlags):
return 1 << 29
@flag_value
def manage_expressions(self) -> int:
""":class:`bool`: Returns ``True`` if a user can edit or delete emojis, stickers, and soundboard sounds.
.. versionadded:: 2.3
"""
return 1 << 30
@make_permission_alias('manage_expressions')
def manage_emojis(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
""":class:`bool`: An alias for :attr:`manage_expressions`."""
return 1 << 30
@make_permission_alias('manage_emojis')
@make_permission_alias('manage_expressions')
def manage_emojis_and_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`manage_emojis`.
""":class:`bool`: An alias for :attr:`manage_expressions`.
.. versionadded:: 2.0
"""
@ -644,6 +722,78 @@ class Permissions(BaseFlags):
"""
return 1 << 40
@flag_value
def view_creator_monetization_analytics(self) -> int:
""":class:`bool`: Returns ``True`` if a user can view role subscription insights.
.. versionadded:: 2.4
"""
return 1 << 41
@flag_value
def use_soundboard(self) -> int:
""":class:`bool`: Returns ``True`` if a user can use the soundboard.
.. versionadded:: 2.3
"""
return 1 << 42
@flag_value
def create_expressions(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create emojis, stickers, and soundboard sounds.
.. versionadded:: 2.3
"""
return 1 << 43
@flag_value
def create_events(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create guild events.
.. versionadded:: 2.4
"""
return 1 << 44
@flag_value
def use_external_sounds(self) -> int:
""":class:`bool`: Returns ``True`` if a user can use sounds from other guilds.
.. versionadded:: 2.3
"""
return 1 << 45
@flag_value
def send_voice_messages(self) -> int:
""":class:`bool`: Returns ``True`` if a user can send voice messages.
.. versionadded:: 2.3
"""
return 1 << 46
@flag_value
def send_polls(self) -> int:
""":class:`bool`: Returns ``True`` if a user can send poll messages.
.. versionadded:: 2.4
"""
return 1 << 49
@make_permission_alias('send_polls')
def create_polls(self) -> int:
""":class:`bool`: An alias for :attr:`send_polls`.
.. versionadded:: 2.4
"""
return 1 << 49
@flag_value
def use_external_apps(self) -> int:
""":class:`bool`: Returns ``True`` if a user can use external apps.
.. versionadded:: 2.4
"""
return 1 << 50
def _augment_from_permissions(cls):
cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
@ -745,6 +895,7 @@ class PermissionOverwrite:
manage_roles: Optional[bool]
manage_permissions: Optional[bool]
manage_webhooks: Optional[bool]
manage_expressions: Optional[bool]
manage_emojis: Optional[bool]
manage_emojis_and_stickers: Optional[bool]
use_application_commands: Optional[bool]
@ -758,6 +909,14 @@ class PermissionOverwrite:
use_external_stickers: Optional[bool]
use_embedded_activities: Optional[bool]
moderate_members: Optional[bool]
use_soundboard: Optional[bool]
use_external_sounds: Optional[bool]
send_voice_messages: Optional[bool]
create_expressions: Optional[bool]
create_events: Optional[bool]
send_polls: Optional[bool]
create_polls: Optional[bool]
use_external_apps: Optional[bool]
def __init__(self, **kwargs: Optional[bool]):
self._values: Dict[str, Optional[bool]] = {}

163
discord/player.py

@ -25,6 +25,7 @@ from __future__ import annotations
import threading
import subprocess
import warnings
import audioop
import asyncio
import logging
@ -39,7 +40,7 @@ from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, T
from .enums import SpeakingState
from .errors import ClientException
from .opus import Encoder as OpusEncoder
from .opus import Encoder as OpusEncoder, OPUS_SILENCE
from .oggparse import OggStream
from .utils import MISSING
@ -145,6 +146,8 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3
"""
BLOCKSIZE: int = io.DEFAULT_BUFFER_SIZE
def __init__(
self,
source: Union[str, io.BufferedIOBase],
@ -153,12 +156,25 @@ class FFmpegAudio(AudioSource):
args: Any,
**subprocess_kwargs: Any,
):
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
if piping and isinstance(source, str):
piping_stdin = subprocess_kwargs.get('stdin') == subprocess.PIPE
if piping_stdin and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
stderr: Optional[IO[bytes]] = subprocess_kwargs.pop('stderr', None)
if stderr == subprocess.PIPE:
warnings.warn("Passing subprocess.PIPE does nothing", DeprecationWarning, stacklevel=3)
stderr = None
piping_stderr = False
if stderr is not None:
try:
stderr.fileno()
except Exception:
piping_stderr = True
args = [executable, *args]
kwargs = {'stdout': subprocess.PIPE}
kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE if piping_stderr else stderr}
kwargs.update(subprocess_kwargs)
# Ensure attribute is assigned even in the case of errors
@ -166,15 +182,24 @@ class FFmpegAudio(AudioSource):
self._process = self._spawn_process(args, **kwargs)
self._stdout: IO[bytes] = self._process.stdout # type: ignore # process stdout is explicitly set
self._stdin: Optional[IO[bytes]] = None
self._pipe_thread: Optional[threading.Thread] = None
self._stderr: Optional[IO[bytes]] = None
self._pipe_writer_thread: Optional[threading.Thread] = None
self._pipe_reader_thread: Optional[threading.Thread] = None
if piping:
n = f'popen-stdin-writer:{id(self):#x}'
if piping_stdin:
n = f'popen-stdin-writer:pid-{self._process.pid}'
self._stdin = self._process.stdin
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_thread.start()
self._pipe_writer_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_writer_thread.start()
if piping_stderr:
n = f'popen-stderr-reader:pid-{self._process.pid}'
self._stderr = self._process.stderr
self._pipe_reader_thread = threading.Thread(target=self._pipe_reader, args=(stderr,), daemon=True, name=n)
self._pipe_reader_thread.start()
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
_log.debug('Spawning ffmpeg process with command: %s', args)
process = None
try:
process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs)
@ -187,7 +212,8 @@ class FFmpegAudio(AudioSource):
return process
def _kill_process(self) -> None:
proc = self._process
# this function gets called in __del__ so instance attributes might not even exist
proc = getattr(self, '_process', MISSING)
if proc is MISSING:
return
@ -207,10 +233,10 @@ class FFmpegAudio(AudioSource):
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process:
# arbitrarily large read size
data = source.read(8192)
data = source.read(self.BLOCKSIZE)
if not data:
self._process.terminate()
if self._stdin is not None:
self._stdin.close()
return
try:
if self._stdin is not None:
@ -221,9 +247,27 @@ class FFmpegAudio(AudioSource):
self._process.terminate()
return
def _pipe_reader(self, dest: IO[bytes]) -> None:
while self._process:
if self._stderr is None:
return
try:
data: bytes = self._stderr.read(self.BLOCKSIZE)
except Exception:
_log.debug('Read error for %s, this is probably not a problem', self, exc_info=True)
return
if data is None:
return
try:
dest.write(data)
except Exception:
_log.exception('Write error for %s', self)
self._stderr.close()
return
def cleanup(self) -> None:
self._kill_process()
self._process = self._stdout = self._stdin = MISSING
self._process = self._stdout = self._stdin = self._stderr = MISSING
class FFmpegPCMAudio(FFmpegAudio):
@ -244,12 +288,17 @@ class FFmpegPCMAudio(FFmpegAudio):
passed to the stdin of ffmpeg.
executable: :class:`str`
The executable name (and path) to use. Defaults to ``ffmpeg``.
.. warning::
Since this class spawns a subprocess, care should be taken to not
pass in an arbitrary executable name when using this parameter.
pipe: :class:`bool`
If ``True``, denotes that ``source`` parameter will be passed
to the stdin of ffmpeg. Defaults to ``False``.
stderr: Optional[:term:`py:file object`]
A file-like object to pass to the Popen constructor.
Could also be an instance of ``subprocess.PIPE``.
before_options: Optional[:class:`str`]
Extra command line arguments to pass to ffmpeg before the ``-i`` flag.
options: Optional[:class:`str`]
@ -267,7 +316,7 @@ class FFmpegPCMAudio(FFmpegAudio):
*,
executable: str = 'ffmpeg',
pipe: bool = False,
stderr: Optional[IO[str]] = None,
stderr: Optional[IO[bytes]] = None,
before_options: Optional[str] = None,
options: Optional[str] = None,
) -> None:
@ -279,7 +328,14 @@ class FFmpegPCMAudio(FFmpegAudio):
args.append('-i')
args.append('-' if pipe else source)
args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning'))
# fmt: off
args.extend(('-f', 's16le',
'-ar', '48000',
'-ac', '2',
'-loglevel', 'warning',
'-blocksize', str(self.BLOCKSIZE)))
# fmt: on
if isinstance(options, str):
args.extend(shlex.split(options))
@ -342,12 +398,17 @@ class FFmpegOpusAudio(FFmpegAudio):
executable: :class:`str`
The executable name (and path) to use. Defaults to ``ffmpeg``.
.. warning::
Since this class spawns a subprocess, care should be taken to not
pass in an arbitrary executable name when using this parameter.
pipe: :class:`bool`
If ``True``, denotes that ``source`` parameter will be passed
to the stdin of ffmpeg. Defaults to ``False``.
stderr: Optional[:term:`py:file object`]
A file-like object to pass to the Popen constructor.
Could also be an instance of ``subprocess.PIPE``.
before_options: Optional[:class:`str`]
Extra command line arguments to pass to ffmpeg before the ``-i`` flag.
options: Optional[:class:`str`]
@ -380,7 +441,7 @@ class FFmpegOpusAudio(FFmpegAudio):
args.append('-i')
args.append('-' if pipe else source)
codec = 'copy' if codec in ('opus', 'libopus') else 'libopus'
codec = 'copy' if codec in ('opus', 'libopus', 'copy') else 'libopus'
bitrate = bitrate if bitrate is not None else 128
# fmt: off
@ -390,7 +451,10 @@ class FFmpegOpusAudio(FFmpegAudio):
'-ar', '48000',
'-ac', '2',
'-b:a', f'{bitrate}k',
'-loglevel', 'warning'))
'-loglevel', 'warning',
'-fec', 'true',
'-packet_loss', '15',
'-blocksize', str(self.BLOCKSIZE)))
# fmt: on
if isinstance(options, str):
@ -524,22 +588,26 @@ class FFmpegOpusAudio(FFmpegAudio):
loop = asyncio.get_running_loop()
try:
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable))
except Exception:
except (KeyboardInterrupt, SystemExit):
raise
except BaseException:
if not fallback:
_log.exception("Probe '%s' using '%s' failed", method, executable)
return # type: ignore
return None, None
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable))
except Exception:
except (KeyboardInterrupt, SystemExit):
raise
except BaseException:
_log.exception("Fallback probe using '%s' failed", executable)
else:
_log.debug("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
else:
_log.debug("Probe found codec=%s, bitrate=%s", codec, bitrate)
finally:
return codec, bitrate
return codec, bitrate
@staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
@ -642,8 +710,7 @@ class AudioPlayer(threading.Thread):
*,
after: Optional[Callable[[Optional[Exception]], Any]] = None,
) -> None:
threading.Thread.__init__(self)
self.daemon: bool = True
super().__init__(daemon=True, name=f'audio-player:{id(self):#x}')
self.source: AudioSource = source
self.client: VoiceClient = client
self.after: Optional[Callable[[Optional[Exception]], Any]] = after
@ -652,7 +719,6 @@ class AudioPlayer(threading.Thread):
self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None
self._connected: threading.Event = client._connected
self._lock: threading.Lock = threading.Lock()
if after is not None and not callable(after):
@ -663,36 +729,47 @@ class AudioPlayer(threading.Thread):
self._start = time.perf_counter()
# getattr lookup speed ups
play_audio = self.client.send_audio_packet
client = self.client
play_audio = client.send_audio_packet
self._speak(SpeakingState.voice)
while not self._end.is_set():
# are we paused?
if not self._resumed.is_set():
self.send_silence()
# wait until we aren't
self._resumed.wait()
continue
# are we disconnected from voice?
if not self._connected.is_set():
# wait until we are connected
self._connected.wait()
# reset our internal data
self.loops = 0
self._start = time.perf_counter()
self.loops += 1
data = self.source.read()
if not data:
self.stop()
break
# are we disconnected from voice?
if not client.is_connected():
_log.debug('Not connected, waiting for %ss...', client.timeout)
# wait until we are connected, but not forever
connected = client.wait_until_connected(client.timeout)
if self._end.is_set() or not connected:
_log.debug('Aborting playback')
return
_log.debug('Reconnected, resuming playback')
self._speak(SpeakingState.voice)
# reset our internal data
self.loops = 0
self._start = time.perf_counter()
play_audio(data, encode=not self.source.is_opus())
self.loops += 1
next_time = self._start + self.DELAY * self.loops
delay = max(0, self.DELAY + (next_time - time.perf_counter()))
time.sleep(delay)
if client.is_connected():
self.send_silence()
def run(self) -> None:
try:
self._do_run()
@ -738,7 +815,7 @@ class AudioPlayer(threading.Thread):
def is_paused(self) -> bool:
return not self._end.is_set() and not self._resumed.is_set()
def _set_source(self, source: AudioSource) -> None:
def set_source(self, source: AudioSource) -> None:
with self._lock:
self.pause(update_speaking=False)
self.source = source
@ -749,3 +826,11 @@ class AudioPlayer(threading.Thread):
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop)
except Exception:
_log.exception("Speaking call in player failed")
def send_silence(self, count: int = 5) -> None:
try:
for n in range(count):
self.client.send_audio_packet(OPUS_SILENCE, encode=False)
except Exception:
# Any possible error (probably a socket error) is so inconsequential it's not even worth logging
pass

672
discord/poll.py

@ -0,0 +1,672 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional, List, TYPE_CHECKING, Union, AsyncIterator, Dict
import datetime
from .enums import PollLayoutType, try_enum, MessageType
from . import utils
from .emoji import PartialEmoji, Emoji
from .user import User
from .object import Object
from .errors import ClientException
if TYPE_CHECKING:
from typing_extensions import Self
from .message import Message
from .abc import Snowflake
from .state import ConnectionState
from .member import Member
from .types.poll import (
PollCreate as PollCreatePayload,
PollMedia as PollMediaPayload,
PollAnswerCount as PollAnswerCountPayload,
Poll as PollPayload,
PollAnswerWithID as PollAnswerWithIDPayload,
PollResult as PollResultPayload,
PollAnswer as PollAnswerPayload,
)
__all__ = (
'Poll',
'PollAnswer',
'PollMedia',
)
MISSING = utils.MISSING
PollMediaEmoji = Union[PartialEmoji, Emoji, str]
class PollMedia:
"""Represents the poll media for a poll item.
.. versionadded:: 2.4
Attributes
----------
text: :class:`str`
The displayed text.
emoji: Optional[Union[:class:`PartialEmoji`, :class:`Emoji`]]
The attached emoji for this media. This is only valid for poll answers.
"""
__slots__ = ('text', 'emoji')
def __init__(self, /, text: str, emoji: Optional[PollMediaEmoji] = None) -> None:
self.text: str = text
self.emoji: Optional[Union[PartialEmoji, Emoji]] = PartialEmoji.from_str(emoji) if isinstance(emoji, str) else emoji
def __repr__(self) -> str:
return f'<PollMedia text={self.text!r} emoji={self.emoji!r}>'
def to_dict(self) -> PollMediaPayload:
payload: PollMediaPayload = {'text': self.text}
if self.emoji is not None:
payload['emoji'] = self.emoji._to_partial().to_dict()
return payload
@classmethod
def from_dict(cls, *, data: PollMediaPayload) -> Self:
emoji = data.get('emoji')
if emoji:
return cls(text=data['text'], emoji=PartialEmoji.from_dict(emoji))
return cls(text=data['text'])
class PollAnswer:
"""Represents a poll's answer.
.. container:: operations
.. describe:: str(x)
Returns this answer's text, if any.
.. versionadded:: 2.4
Attributes
----------
id: :class:`int`
The ID of this answer.
media: :class:`PollMedia`
The display data for this answer.
self_voted: :class:`bool`
Whether the current user has voted to this answer or not.
"""
__slots__ = (
'media',
'id',
'_state',
'_message',
'_vote_count',
'self_voted',
'_poll',
'_victor',
)
def __init__(
self,
*,
message: Optional[Message],
poll: Poll,
data: PollAnswerWithIDPayload,
) -> None:
self.media: PollMedia = PollMedia.from_dict(data=data['poll_media'])
self.id: int = int(data['answer_id'])
self._message: Optional[Message] = message
self._state: Optional[ConnectionState] = message._state if message else None
self._vote_count: int = 0
self.self_voted: bool = False
self._poll: Poll = poll
self._victor: bool = False
def _handle_vote_event(self, added: bool, self_voted: bool) -> None:
if added:
self._vote_count += 1
else:
self._vote_count -= 1
self.self_voted = self_voted
def _update_with_results(self, payload: PollAnswerCountPayload) -> None:
self._vote_count = int(payload['count'])
self.self_voted = payload['me_voted']
def __str__(self) -> str:
return self.media.text
def __repr__(self) -> str:
return f'<PollAnswer id={self.id} media={self.media!r}>'
@classmethod
def from_params(
cls,
id: int,
text: str,
emoji: Optional[PollMediaEmoji] = None,
*,
poll: Poll,
message: Optional[Message],
) -> Self:
poll_media: PollMediaPayload = {'text': text}
if emoji is not None:
emoji = PartialEmoji.from_str(emoji) if isinstance(emoji, str) else emoji._to_partial()
emoji_data = emoji.to_dict()
# No need to remove animated key as it will be ignored
poll_media['emoji'] = emoji_data
payload: PollAnswerWithIDPayload = {'answer_id': id, 'poll_media': poll_media}
return cls(data=payload, message=message, poll=poll)
@property
def text(self) -> str:
""":class:`str`: Returns this answer's displayed text."""
return self.media.text
@property
def emoji(self) -> Optional[Union[PartialEmoji, Emoji]]:
"""Optional[Union[:class:`Emoji`, :class:`PartialEmoji`]]: Returns this answer's displayed
emoji, if any.
"""
return self.media.emoji
@property
def vote_count(self) -> int:
""":class:`int`: Returns an approximate count of votes for this answer.
If the poll is finished, the count is exact.
"""
return self._vote_count
@property
def poll(self) -> Poll:
""":class:`Poll`: Returns the parent poll of this answer."""
return self._poll
def _to_dict(self) -> PollAnswerPayload:
return {
'poll_media': self.media.to_dict(),
}
@property
def victor(self) -> bool:
""":class:`bool`: Whether the answer is the one that had the most
votes when the poll ended.
.. versionadded:: 2.5
.. note::
If the poll has not ended, this will always return ``False``.
"""
return self._victor
async def voters(
self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None
) -> AsyncIterator[Union[User, Member]]:
"""Returns an :term:`asynchronous iterator` representing the users that have voted on this answer.
The ``after`` parameter must represent a user
and meet the :class:`abc.Snowflake` abc.
This can only be called when the parent poll was sent to a message.
Examples
--------
Usage ::
async for voter in poll_answer.voters():
print(f'{voter} has voted for {poll_answer}!')
Flattening into a list: ::
voters = [voter async for voter in poll_answer.voters()]
# voters is now a list of User
Parameters
----------
limit: Optional[:class:`int`]
The maximum number of results to return.
If not provided, returns all the users who
voted on this poll answer.
after: Optional[:class:`abc.Snowflake`]
For pagination, voters are sorted by member.
Raises
------
HTTPException
Retrieving the users failed.
Yields
------
Union[:class:`User`, :class:`Member`]
The member (if retrievable) or the user that has voted
on this poll answer. The case where it can be a :class:`Member`
is in a guild message context. Sometimes it can be a :class:`User`
if the member has left the guild or if the member is not cached.
"""
if not self._message or not self._state: # Make type checker happy
raise ClientException('You cannot fetch users to a poll not sent with a message')
if limit is None:
if not self._message.poll:
limit = 100
else:
limit = self.vote_count or 100
while limit > 0:
retrieve = min(limit, 100)
message = self._message
guild = self._message.guild
state = self._state
after_id = after.id if after else None
data = await state.http.get_poll_answer_voters(
message.channel.id, message.id, self.id, after=after_id, limit=retrieve
)
users = data['users']
if len(users) == 0:
# No more voters to fetch, terminate loop
break
limit -= len(users)
after = Object(id=int(users[-1]['id']))
if not guild or isinstance(guild, Object):
for raw_user in reversed(users):
yield User(state=self._state, data=raw_user)
continue
for raw_member in reversed(users):
member_id = int(raw_member['id'])
member = guild.get_member(member_id)
yield member or User(state=self._state, data=raw_member)
class Poll:
"""Represents a message's Poll.
.. versionadded:: 2.4
Parameters
----------
question: Union[:class:`PollMedia`, :class:`str`]
The poll's displayed question. The text can be up to 300 characters.
duration: :class:`datetime.timedelta`
The duration of the poll. Duration must be in hours.
multiple: :class:`bool`
Whether users are allowed to select more than one answer.
Defaults to ``False``.
layout_type: :class:`PollLayoutType`
The layout type of the poll. Defaults to :attr:`PollLayoutType.default`.
Attributes
-----------
duration: :class:`datetime.timedelta`
The duration of the poll.
multiple: :class:`bool`
Whether users are allowed to select more than one answer.
layout_type: :class:`PollLayoutType`
The layout type of the poll.
"""
__slots__ = (
'multiple',
'_answers',
'duration',
'layout_type',
'_question_media',
'_message',
'_expiry',
'_finalized',
'_state',
'_total_votes',
'_victor_answer_id',
)
def __init__(
self,
question: Union[PollMedia, str],
duration: datetime.timedelta,
*,
multiple: bool = False,
layout_type: PollLayoutType = PollLayoutType.default,
) -> None:
self._question_media: PollMedia = PollMedia(text=question, emoji=None) if isinstance(question, str) else question
self._answers: Dict[int, PollAnswer] = {}
self.duration: datetime.timedelta = duration
self.multiple: bool = multiple
self.layout_type: PollLayoutType = layout_type
# NOTE: These attributes are set manually when calling
# _from_data, so it should be ``None`` now.
self._message: Optional[Message] = None
self._state: Optional[ConnectionState] = None
self._finalized: bool = False
self._expiry: Optional[datetime.datetime] = None
self._total_votes: Optional[int] = None
self._victor_answer_id: Optional[int] = None
def _update(self, message: Message) -> None:
self._state = message._state
self._message = message
if not message.poll:
return
# The message's poll contains the more up to date data.
self._expiry = message.poll.expires_at
self._finalized = message.poll._finalized
self._answers = message.poll._answers
self._update_results_from_message(message)
def _update_results_from_message(self, message: Message) -> None:
if message.type != MessageType.poll_result or not message.embeds:
return
result_embed = message.embeds[0] # Will always have 1 embed
fields: Dict[str, str] = {field.name: field.value for field in result_embed.fields} # type: ignore
total_votes = fields.get('total_votes')
if total_votes is not None:
self._total_votes = int(total_votes)
victor_answer = fields.get('victor_answer_id')
if victor_answer is None:
return # Can't do anything else without the victor answer
self._victor_answer_id = int(victor_answer)
victor_answer_votes = fields['victor_answer_votes']
answer = self._answers[self._victor_answer_id]
answer._victor = True
answer._vote_count = int(victor_answer_votes)
self._answers[answer.id] = answer # Ensure update
def _update_results(self, data: PollResultPayload) -> None:
self._finalized = data['is_finalized']
for count in data['answer_counts']:
answer = self.get_answer(int(count['id']))
if not answer:
continue
answer._update_with_results(count)
def _handle_vote(self, answer_id: int, added: bool, self_voted: bool = False):
answer = self.get_answer(answer_id)
if not answer:
return
answer._handle_vote_event(added, self_voted)
@classmethod
def _from_data(cls, *, data: PollPayload, message: Message, state: ConnectionState) -> Self:
multiselect = data.get('allow_multiselect', False)
layout_type = try_enum(PollLayoutType, data.get('layout_type', 1))
question_data = data.get('question')
question = question_data.get('text')
expiry = utils.parse_time(data['expiry']) # If obtained via API, then expiry is set.
# expiry - message.created_at may be a few nanos away from the actual duration
duration = datetime.timedelta(hours=round((expiry - message.created_at).total_seconds() / 3600))
# self.created_at = message.created_at
self = cls(
duration=duration,
multiple=multiselect,
layout_type=layout_type,
question=question,
)
self._answers = {
int(answer['answer_id']): PollAnswer(data=answer, message=message, poll=self) for answer in data['answers']
}
self._message = message
self._state = state
self._expiry = expiry
try:
self._update_results(data['results'])
except KeyError:
pass
return self
def _to_dict(self) -> PollCreatePayload:
data: PollCreatePayload = {
'allow_multiselect': self.multiple,
'question': self._question_media.to_dict(),
'duration': self.duration.total_seconds() / 3600,
'layout_type': self.layout_type.value,
'answers': [answer._to_dict() for answer in self.answers],
}
return data
def __repr__(self) -> str:
return f"<Poll duration={self.duration} question=\"{self.question}\" answers={self.answers}>"
@property
def question(self) -> str:
""":class:`str`: Returns this poll's question string."""
return self._question_media.text
@property
def answers(self) -> List[PollAnswer]:
"""List[:class:`PollAnswer`]: Returns a read-only copy of the answers."""
return list(self._answers.values())
@property
def victor_answer_id(self) -> Optional[int]:
"""Optional[:class:`int`]: The victor answer ID.
.. versionadded:: 2.5
.. note::
This will **always** be ``None`` for polls that have not yet finished.
"""
return self._victor_answer_id
@property
def victor_answer(self) -> Optional[PollAnswer]:
"""Optional[:class:`PollAnswer`]: The victor answer.
.. versionadded:: 2.5
.. note::
This will **always** be ``None`` for polls that have not yet finished.
"""
if self.victor_answer_id is None:
return None
return self.get_answer(self.victor_answer_id)
@property
def expires_at(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: A datetime object representing the poll expiry.
.. note::
This will **always** be ``None`` for stateless polls.
"""
return self._expiry
@property
def created_at(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: Returns the poll's creation time.
.. note::
This will **always** be ``None`` for stateless polls.
"""
if not self._message:
return
return self._message.created_at
@property
def message(self) -> Optional[Message]:
"""Optional[:class:`Message`]: The message this poll is from."""
return self._message
@property
def total_votes(self) -> int:
""":class:`int`: Returns the sum of all the answer votes.
If the poll has not yet finished, this is an approximate vote count.
.. versionchanged:: 2.5
This now returns an exact vote count when updated from its poll results message.
"""
if self._total_votes is not None:
return self._total_votes
return sum([answer.vote_count for answer in self.answers])
def is_finalised(self) -> bool:
""":class:`bool`: Returns whether the poll has finalised.
This always returns ``False`` for stateless polls.
"""
return self._finalized
is_finalized = is_finalised
def copy(self) -> Self:
"""Returns a stateless copy of this poll.
This is meant to be used when you want to edit a stateful poll.
Returns
-------
:class:`Poll`
The copy of the poll.
"""
new = self.__class__(question=self.question, duration=self.duration)
# We want to return a stateless copy of the poll, so we should not
# override new._answers as our answers may contain a state
for answer in self.answers:
new.add_answer(text=answer.text, emoji=answer.emoji)
return new
def add_answer(
self,
*,
text: str,
emoji: Optional[Union[PartialEmoji, Emoji, str]] = None,
) -> Self:
"""Appends a new answer to this poll.
Parameters
----------
text: :class:`str`
The text label for this poll answer. Can be up to 55
characters.
emoji: Union[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`]
The emoji to display along the text.
Raises
------
ClientException
Cannot append answers to a poll that is active.
Returns
-------
:class:`Poll`
This poll with the new answer appended. This allows fluent-style chaining.
"""
if self._message:
raise ClientException('Cannot append answers to a poll that is active')
answer = PollAnswer.from_params(id=len(self.answers) + 1, text=text, emoji=emoji, message=self._message, poll=self)
self._answers[answer.id] = answer
return self
def get_answer(
self,
/,
id: int,
) -> Optional[PollAnswer]:
"""Returns the answer with the provided ID or ``None`` if not found.
Parameters
----------
id: :class:`int`
The ID of the answer to get.
Returns
-------
Optional[:class:`PollAnswer`]
The answer.
"""
return self._answers.get(id)
async def end(self) -> Self:
"""|coro|
Ends the poll.
Raises
------
ClientException
This poll has no attached message.
HTTPException
Ending the poll failed.
Returns
-------
:class:`Poll`
The updated poll.
"""
if not self._message or not self._state: # Make type checker happy
raise ClientException('This poll has no attached message.')
message = await self._message.end_poll()
self._update(message)
return self

150
discord/presences.py

@ -0,0 +1,150 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Tuple
from .activity import create_activity
from .enums import Status, try_enum
from .utils import MISSING, _get_as_snowflake, _RawReprMixin
if TYPE_CHECKING:
from typing_extensions import Self
from .activity import ActivityTypes
from .guild import Guild
from .state import ConnectionState
from .types.activity import ClientStatus as ClientStatusPayload, PartialPresenceUpdate
__all__ = (
'RawPresenceUpdateEvent',
'ClientStatus',
)
class ClientStatus:
"""Represents the :ddocs:`Client Status Object <events/gateway-events#client-status-object>` from Discord,
which holds information about the status of the user on various clients/platforms, with additional helpers.
.. versionadded:: 2.5
"""
__slots__ = ('_status', 'desktop', 'mobile', 'web')
def __init__(self, *, status: str = MISSING, data: ClientStatusPayload = MISSING) -> None:
self._status: str = status or 'offline'
data = data or {}
self.desktop: Optional[str] = data.get('desktop')
self.mobile: Optional[str] = data.get('mobile')
self.web: Optional[str] = data.get('web')
def __repr__(self) -> str:
attrs = [
('_status', self._status),
('desktop', self.desktop),
('mobile', self.mobile),
('web', self.web),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>'
def _update(self, status: str, data: ClientStatusPayload, /) -> None:
self._status = status
self.desktop = data.get('desktop')
self.mobile = data.get('mobile')
self.web = data.get('web')
@classmethod
def _copy(cls, client_status: Self, /) -> Self:
self = cls.__new__(cls) # bypass __init__
self._status = client_status._status
self.desktop = client_status.desktop
self.mobile = client_status.mobile
self.web = client_status.web
return self
@property
def status(self) -> Status:
""":class:`Status`: The user's overall status. If the value is unknown, then it will be a :class:`str` instead."""
return try_enum(Status, self._status)
@property
def raw_status(self) -> str:
""":class:`str`: The user's overall status as a string value."""
return self._status
@property
def mobile_status(self) -> Status:
""":class:`Status`: The user's status on a mobile device, if applicable."""
return try_enum(Status, self.mobile or 'offline')
@property
def desktop_status(self) -> Status:
""":class:`Status`: The user's status on the desktop client, if applicable."""
return try_enum(Status, self.desktop or 'offline')
@property
def web_status(self) -> Status:
""":class:`Status`: The user's status on the web client, if applicable."""
return try_enum(Status, self.web or 'offline')
def is_on_mobile(self) -> bool:
""":class:`bool`: A helper function that determines if a user is active on a mobile device."""
return self.mobile is not None
class RawPresenceUpdateEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_presence_update` event.
.. versionadded:: 2.5
Attributes
----------
user_id: :class:`int`
The ID of the user that triggered the presence update.
guild_id: Optional[:class:`int`]
The guild ID for the users presence update. Could be ``None``.
guild: Optional[:class:`Guild`]
The guild associated with the presence update and user. Could be ``None``.
client_status: :class:`ClientStatus`
The :class:`~.ClientStatus` model which holds information about the status of the user on various clients.
activities: Tuple[Union[:class:`BaseActivity`, :class:`Spotify`]]
The activities the user is currently doing. Due to a Discord API limitation, a user's Spotify activity may not appear
if they are listening to a song with a title longer than ``128`` characters. See :issue:`1738` for more information.
"""
__slots__ = ('user_id', 'guild_id', 'guild', 'client_status', 'activities')
def __init__(self, *, data: PartialPresenceUpdate, state: ConnectionState) -> None:
self.user_id: int = int(data['user']['id'])
self.client_status: ClientStatus = ClientStatus(status=data['status'], data=data['client_status'])
self.activities: Tuple[ActivityTypes, ...] = tuple(create_activity(d, state) for d in data['activities'])
self.guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
self.guild: Optional[Guild] = state._get_guild(self.guild_id)

148
discord/raw_models.py

@ -25,13 +25,16 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Optional, Set, List, Tuple, Union
from typing import TYPE_CHECKING, Literal, Optional, Set, List, Union
from .enums import ChannelType, try_enum
from .utils import _get_as_snowflake
from .enums import ChannelType, try_enum, ReactionType
from .utils import _get_as_snowflake, _RawReprMixin
from .app_commands import AppCommandPermissions
from .colour import Colour
if TYPE_CHECKING:
from typing_extensions import Self
from .types.gateway import (
MessageDeleteEvent,
MessageDeleteBulkEvent as BulkMessageDeleteEvent,
@ -46,6 +49,7 @@ if TYPE_CHECKING:
ThreadMembersUpdate,
TypingStartEvent,
GuildMemberRemoveEvent,
PollVoteActionEvent,
)
from .types.command import GuildApplicationCommandPermissions
from .message import Message
@ -57,6 +61,7 @@ if TYPE_CHECKING:
from .guild import Guild
ReactionActionEvent = Union[MessageReactionAddEvent, MessageReactionRemoveEvent]
ReactionActionType = Literal['REACTION_ADD', 'REACTION_REMOVE']
__all__ = (
@ -73,17 +78,10 @@ __all__ = (
'RawTypingEvent',
'RawMemberRemoveEvent',
'RawAppCommandPermissionsUpdateEvent',
'RawPollVoteActionEvent',
)
class _RawReprMixin:
__slots__: Tuple[str, ...] = ()
def __repr__(self) -> str:
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
return f'<{self.__class__.__name__} {value}>'
class RawMessageDeleteEvent(_RawReprMixin):
"""Represents the event payload for a :func:`on_raw_message_delete` event.
@ -106,7 +104,7 @@ class RawMessageDeleteEvent(_RawReprMixin):
self.channel_id: int = int(data['channel_id'])
self.cached_message: Optional[Message] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.guild_id: Optional[int] = None
@ -134,7 +132,7 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
self.cached_messages: List[Message] = []
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.guild_id: Optional[int] = None
@ -156,24 +154,26 @@ class RawMessageUpdateEvent(_RawReprMixin):
.. versionadded:: 1.7
data: :class:`dict`
The raw data given by the :ddocs:`gateway <topics/gateway#message-update>`
The raw data given by the :ddocs:`gateway <topics/gateway-events#message-update>`
cached_message: Optional[:class:`Message`]
The cached message, if found in the internal message cache. Represents the message before
it is modified by the data in :attr:`RawMessageUpdateEvent.data`.
message: :class:`Message`
The updated message.
.. versionadded:: 2.5
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message', 'message')
def __init__(self, data: MessageUpdateEvent) -> None:
self.message_id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
def __init__(self, data: MessageUpdateEvent, message: Message) -> None:
self.message_id: int = message.id
self.channel_id: int = message.channel.id
self.data: MessageUpdateEvent = data
self.message: Message = message
self.cached_message: Optional[Message] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None
self.guild_id: Optional[int] = message.guild.id if message.guild else None
class RawReactionActionEvent(_RawReprMixin):
@ -196,30 +196,70 @@ class RawReactionActionEvent(_RawReprMixin):
The member who added the reaction. Only available if ``event_type`` is ``REACTION_ADD`` and the reaction is inside a guild.
.. versionadded:: 1.3
message_author_id: Optional[:class:`int`]
The author ID of the message being reacted to. Only available if ``event_type`` is ``REACTION_ADD``.
.. versionadded:: 2.4
event_type: :class:`str`
The event type that triggered this action. Can be
``REACTION_ADD`` for reaction addition or
``REACTION_REMOVE`` for reaction removal.
.. versionadded:: 1.3
burst: :class:`bool`
Whether the reaction was a burst reaction, also known as a "super reaction".
.. versionadded:: 2.4
burst_colours: List[:class:`Colour`]
A list of colours used for burst reaction animation. Only available if ``burst`` is ``True``
and if ``event_type`` is ``REACTION_ADD``.
.. versionadded:: 2.0
type: :class:`ReactionType`
The type of the reaction.
.. versionadded:: 2.4
"""
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', 'event_type', 'member')
__slots__ = (
'message_id',
'user_id',
'channel_id',
'guild_id',
'emoji',
'event_type',
'member',
'message_author_id',
'burst',
'burst_colours',
'type',
)
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: ReactionActionType) -> None:
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.user_id: int = int(data['user_id'])
self.emoji: PartialEmoji = emoji
self.event_type: str = event_type
self.event_type: ReactionActionType = event_type
self.member: Optional[Member] = None
self.message_author_id: Optional[int] = _get_as_snowflake(data, 'message_author_id')
self.burst: bool = data.get('burst', False)
self.burst_colours: List[Colour] = [Colour.from_str(c) for c in data.get('burst_colours', [])]
self.type: ReactionType = try_enum(ReactionType, data['type'])
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.guild_id: Optional[int] = None
@property
def burst_colors(self) -> List[Colour]:
"""An alias of :attr:`burst_colours`.
.. versionadded:: 2.4
"""
return self.burst_colours
class RawReactionClearEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_clear` event.
@ -241,7 +281,7 @@ class RawReactionClearEvent(_RawReprMixin):
self.channel_id: int = int(data['channel_id'])
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.guild_id: Optional[int] = None
@ -271,7 +311,7 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
self.channel_id: int = int(data['channel_id'])
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data['guild_id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.guild_id: Optional[int] = None
@ -298,7 +338,9 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
self.guild_id: int = int(data['guild_id'])
try:
self.application_id: Optional[int] = int(data['application_id'])
self.application_id: Optional[int] = int(
data['application_id'] # pyright: ignore[reportTypedDictNotRequiredAccess]
)
except KeyError:
self.application_id: Optional[int] = None
@ -319,7 +361,7 @@ class RawThreadUpdateEvent(_RawReprMixin):
parent_id: :class:`int`
The ID of the channel the thread belongs to.
data: :class:`dict`
The raw data given by the :ddocs:`gateway <topics/gateway#thread-update>`
The raw data given by the :ddocs:`gateway <topics/gateway-events#thread-update>`
thread: Optional[:class:`discord.Thread`]
The thread, if it could be found in the internal cache.
"""
@ -363,6 +405,20 @@ class RawThreadDeleteEvent(_RawReprMixin):
self.parent_id: int = int(data['parent_id'])
self.thread: Optional[Thread] = None
@classmethod
def _from_thread(cls, thread: Thread) -> Self:
data: ThreadDeleteEvent = {
'id': thread.id,
'type': thread.type.value,
'guild_id': thread.guild.id,
'parent_id': thread.parent_id,
}
instance = cls(data)
instance.thread = thread
return instance
class RawThreadMembersUpdate(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_thread_member_remove` event.
@ -378,7 +434,7 @@ class RawThreadMembersUpdate(_RawReprMixin):
member_count: :class:`int`
The approximate number of members in the thread. This caps at 50.
data: :class:`dict`
The raw data given by the :ddocs:`gateway <topics/gateway#thread-members-update>`.
The raw data given by the :ddocs:`gateway <topics/gateway-events#thread-members-update>`.
"""
__slots__ = ('thread_id', 'guild_id', 'member_count', 'data')
@ -467,3 +523,33 @@ class RawAppCommandPermissionsUpdateEvent(_RawReprMixin):
self.permissions: List[AppCommandPermissions] = [
AppCommandPermissions(data=perm, guild=self.guild, state=state) for perm in data['permissions']
]
class RawPollVoteActionEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_poll_vote_add` or :func:`on_raw_poll_vote_remove`
event.
.. versionadded:: 2.4
Attributes
----------
user_id: :class:`int`
The ID of the user that added or removed a vote.
channel_id: :class:`int`
The channel ID where the poll vote action took place.
message_id: :class:`int`
The message ID that contains the poll the user added or removed their vote on.
guild_id: Optional[:class:`int`]
The guild ID where the vote got added or removed, if applicable..
answer_id: :class:`int`
The poll answer's ID the user voted on.
"""
__slots__ = ('user_id', 'channel_id', 'message_id', 'guild_id', 'answer_id')
def __init__(self, data: PollVoteActionEvent) -> None:
self.user_id: int = int(data['user_id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data['message_id'])
self.guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
self.answer_id: int = int(data['answer_id'])

42
discord/reaction.py

@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, AsyncIterator, Union, Optional
from .user import User
from .object import Object
from .enums import ReactionType
# fmt: off
__all__ = (
@ -74,22 +75,41 @@ class Reaction:
emoji: Union[:class:`Emoji`, :class:`PartialEmoji`, :class:`str`]
The reaction emoji. May be a custom emoji, or a unicode emoji.
count: :class:`int`
Number of times this reaction was made
Number of times this reaction was made. This is a sum of :attr:`normal_count` and :attr:`burst_count`.
me: :class:`bool`
If the user sent this reaction.
message: :class:`Message`
Message this reaction is for.
me_burst: :class:`bool`
If the user sent this super reaction.
.. versionadded:: 2.4
normal_count: :class:`int`
The number of times this reaction was made using normal reactions.
This is not available in the gateway events such as :func:`on_reaction_add`
or :func:`on_reaction_remove`.
.. versionadded:: 2.4
burst_count: :class:`int`
The number of times this reaction was made using super reactions.
This is not available in the gateway events such as :func:`on_reaction_add`
or :func:`on_reaction_remove`.
.. versionadded:: 2.4
"""
__slots__ = ('message', 'count', 'emoji', 'me')
__slots__ = ('message', 'count', 'emoji', 'me', 'me_burst', 'normal_count', 'burst_count')
def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None):
self.message: Message = message
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji'])
self.count: int = data.get('count', 1)
self.me: bool = data['me']
details = data.get('count_details', {})
self.normal_count: int = details.get('normal', 0)
self.burst_count: int = details.get('burst', 0)
self.me_burst: bool = data.get('me_burst', False)
# TODO: typeguard
def is_custom_emoji(self) -> bool:
""":class:`bool`: If this is a custom emoji."""
return not isinstance(self.emoji, str)
@ -166,7 +186,7 @@ class Reaction:
await self.message.clear_reaction(self.emoji)
async def users(
self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None
self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None, type: Optional[ReactionType] = None
) -> AsyncIterator[Union[Member, User]]:
"""Returns an :term:`asynchronous iterator` representing the users that have reacted to the message.
@ -201,6 +221,11 @@ class Reaction:
reacted to the message.
after: Optional[:class:`abc.Snowflake`]
For pagination, reactions are sorted by member.
type: Optional[:class:`ReactionType`]
The type of reaction to return users from.
If not provided, Discord only returns users of reactions with type ``normal``.
.. versionadded:: 2.4
Raises
--------
@ -232,7 +257,14 @@ class Reaction:
state = message._state
after_id = after.id if after else None
data = await state.http.get_reaction_users(message.channel.id, message.id, emoji, retrieve, after=after_id)
data = await state.http.get_reaction_users(
message.channel.id,
message.id,
emoji,
retrieve,
after=after_id,
type=type.value if type is not None else None,
)
if data:
limit -= len(data)

121
discord/role.py

@ -23,13 +23,14 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Union, overload, TYPE_CHECKING
from .asset import Asset
from .permissions import Permissions
from .colour import Colour
from .mixins import Hashable
from .utils import snowflake_time, _bytes_to_base64_data, _get_as_snowflake, MISSING
from .flags import RoleFlags
__all__ = (
'RoleTags',
@ -219,6 +220,7 @@ class Role(Hashable):
'hoist',
'guild',
'tags',
'_flags',
'_state',
)
@ -281,9 +283,10 @@ class Role(Hashable):
self.managed: bool = data.get('managed', False)
self.mentionable: bool = data.get('mentionable', False)
self.tags: Optional[RoleTags]
self._flags: int = data.get('flags', 0)
try:
self.tags = RoleTags(data['tags'])
self.tags = RoleTags(data['tags']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.tags = None
@ -379,6 +382,14 @@ class Role(Hashable):
role_id = self.id
return [member for member in all_members if member._roles.has(role_id)]
@property
def flags(self) -> RoleFlags:
""":class:`RoleFlags`: Returns the role's flags.
.. versionadded:: 2.4
"""
return RoleFlags._from_value(self._flags)
async def _move(self, position: int, reason: Optional[str]) -> None:
if position <= 0:
raise ValueError("Cannot move role to position 0 or below")
@ -511,6 +522,112 @@ class Role(Hashable):
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
return Role(guild=self.guild, data=data, state=self._state)
@overload
async def move(self, *, beginning: bool, offset: int = ..., reason: Optional[str] = ...):
...
@overload
async def move(self, *, end: bool, offset: int = ..., reason: Optional[str] = ...):
...
@overload
async def move(self, *, above: Role, offset: int = ..., reason: Optional[str] = ...):
...
@overload
async def move(self, *, below: Role, offset: int = ..., reason: Optional[str] = ...):
...
async def move(
self,
*,
beginning: bool = MISSING,
end: bool = MISSING,
above: Role = MISSING,
below: Role = MISSING,
offset: int = 0,
reason: Optional[str] = None,
):
"""|coro|
A rich interface to help move a role relative to other roles.
You must have :attr:`~discord.Permissions.manage_roles` to do this,
and you cannot move roles above the client's top role in the guild.
.. versionadded:: 2.5
Parameters
-----------
beginning: :class:`bool`
Whether to move this at the beginning of the role list, above the default role.
This is mutually exclusive with `end`, `above`, and `below`.
end: :class:`bool`
Whether to move this at the end of the role list.
This is mutually exclusive with `beginning`, `above`, and `below`.
above: :class:`Role`
The role that should be above our current role.
This mutually exclusive with `beginning`, `end`, and `below`.
below: :class:`Role`
The role that should be below our current role.
This mutually exclusive with `beginning`, `end`, and `above`.
offset: :class:`int`
The number of roles to offset the move by. For example,
an offset of ``2`` with ``beginning=True`` would move
it 2 above the beginning. A positive number moves it above
while a negative number moves it below. Note that this
number is relative and computed after the ``beginning``,
``end``, ``before``, and ``after`` parameters.
reason: Optional[:class:`str`]
The reason for editing this role. Shows up on the audit log.
Raises
-------
Forbidden
You cannot move the role there, or lack permissions to do so.
HTTPException
Moving the role failed.
TypeError
A bad mix of arguments were passed.
ValueError
An invalid role was passed.
Returns
--------
List[:class:`Role`]
A list of all the roles in the guild.
"""
if sum(bool(a) for a in (beginning, end, above, below)) > 1:
raise TypeError('Only one of [beginning, end, above, below] can be used.')
target = above or below
guild = self.guild
guild_roles = guild.roles
if target:
if target not in guild_roles:
raise ValueError('Target role is from a different guild')
if above == guild.default_role:
raise ValueError('Role cannot be moved below the default role')
if self == target:
raise ValueError('Target role cannot be itself')
roles = [r for r in guild_roles if r != self]
if beginning:
index = 1
elif end:
index = len(roles)
elif above in roles:
index = roles.index(above)
elif below in roles:
index = roles.index(below) + 1
else:
index = guild_roles.index(self)
roles.insert(max((index + offset), 1), self)
payload: List[RolePositionUpdate] = [{'id': role.id, 'position': idx} for idx, role in enumerate(roles)]
await self._state.http.move_role_position(guild.id, payload, reason=reason)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|

120
discord/scheduled_event.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional, Union
from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional, Union, overload, Literal
from .asset import Asset
from .enums import EventStatus, EntityType, PrivacyLevel, try_enum
@ -298,6 +298,87 @@ class ScheduledEvent(Hashable):
return await self.__modify_status(EventStatus.cancelled, reason)
@overload
async def edit(
self,
*,
name: str = ...,
description: str = ...,
start_time: datetime = ...,
end_time: Optional[datetime] = ...,
privacy_level: PrivacyLevel = ...,
status: EventStatus = ...,
image: bytes = ...,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
@overload
async def edit(
self,
*,
name: str = ...,
description: str = ...,
channel: Snowflake,
start_time: datetime = ...,
end_time: Optional[datetime] = ...,
privacy_level: PrivacyLevel = ...,
entity_type: Literal[EntityType.voice, EntityType.stage_instance],
status: EventStatus = ...,
image: bytes = ...,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
@overload
async def edit(
self,
*,
name: str = ...,
description: str = ...,
start_time: datetime = ...,
end_time: datetime = ...,
privacy_level: PrivacyLevel = ...,
entity_type: Literal[EntityType.external],
status: EventStatus = ...,
image: bytes = ...,
location: str,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
@overload
async def edit(
self,
*,
name: str = ...,
description: str = ...,
channel: Union[VoiceChannel, StageChannel],
start_time: datetime = ...,
end_time: Optional[datetime] = ...,
privacy_level: PrivacyLevel = ...,
status: EventStatus = ...,
image: bytes = ...,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
@overload
async def edit(
self,
*,
name: str = ...,
description: str = ...,
start_time: datetime = ...,
end_time: datetime = ...,
privacy_level: PrivacyLevel = ...,
status: EventStatus = ...,
image: bytes = ...,
location: str,
reason: Optional[str] = ...,
) -> ScheduledEvent:
...
async def edit(
self,
*,
@ -414,24 +495,34 @@ class ScheduledEvent(Hashable):
payload['image'] = image_as_str
entity_type = entity_type or getattr(channel, '_scheduled_event_entity_type', MISSING)
if entity_type is None:
raise TypeError(
f'invalid GuildChannel type passed, must be VoiceChannel or StageChannel not {channel.__class__.__name__}'
)
if entity_type is not MISSING:
if entity_type is MISSING:
if channel and isinstance(channel, Object):
if channel.type is VoiceChannel:
entity_type = EntityType.voice
elif channel.type is StageChannel:
entity_type = EntityType.stage_instance
elif location not in (MISSING, None):
entity_type = EntityType.external
else:
if not isinstance(entity_type, EntityType):
raise TypeError('entity_type must be of type EntityType')
payload['entity_type'] = entity_type.value
if entity_type is None:
raise TypeError(
f'invalid GuildChannel type passed, must be VoiceChannel or StageChannel not {channel.__class__.__name__}'
)
_entity_type = entity_type or self.entity_type
_entity_type_changed = _entity_type is not self.entity_type
if _entity_type in (EntityType.stage_instance, EntityType.voice):
if channel is MISSING or channel is None:
raise TypeError('channel must be set when entity_type is voice or stage_instance')
payload['channel_id'] = channel.id
if _entity_type_changed:
raise TypeError('channel must be set when entity_type is voice or stage_instance')
else:
payload['channel_id'] = channel.id
if location not in (MISSING, None):
raise TypeError('location cannot be set when entity_type is voice or stage_instance')
@ -442,11 +533,12 @@ class ScheduledEvent(Hashable):
payload['channel_id'] = None
if location is MISSING or location is None:
raise TypeError('location must be set when entity_type is external')
metadata['location'] = location
if _entity_type_changed:
raise TypeError('location must be set when entity_type is external')
else:
metadata['location'] = location
if end_time is MISSING or end_time is None:
if not self.end_time and (end_time is MISSING or end_time is None):
raise TypeError('end_time must be set when entity_type is external')
if end_time is not MISSING:

96
discord/shard.py

@ -47,13 +47,16 @@ from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict
if TYPE_CHECKING:
from typing_extensions import Unpack
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .flags import Intents
from .types.gateway import SessionStartLimit
__all__ = (
'AutoShardedClient',
'ShardInfo',
'SessionStartLimits',
)
_log = logging.getLogger(__name__)
@ -192,6 +195,10 @@ class Shard:
self.ws = await asyncio.wait_for(coro, timeout=60.0)
except self._handled_exceptions as e:
await self._handle_disconnect(e)
except ReconnectWebSocket as e:
_log.debug('Somehow got a signal to %s while trying to %s shard ID %s.', e.op, exc.op, self.id)
op = EventType.resume if e.resume else EventType.identify
self._queue_put(EventItem(op, self, e))
except asyncio.CancelledError:
return
except Exception as e:
@ -289,6 +296,32 @@ class ShardInfo:
return self._parent.ws.is_ratelimited()
class SessionStartLimits:
"""A class that holds info about session start limits
.. versionadded:: 2.5
Attributes
----------
total: :class:`int`
The total number of session starts the current user is allowed
remaining: :class:`int`
Remaining remaining number of session starts the current user is allowed
reset_after: :class:`int`
The number of milliseconds until the limit resets
max_concurrency: :class:`int`
The number of identify requests allowed per 5 seconds
"""
__slots__ = ("total", "remaining", "reset_after", "max_concurrency")
def __init__(self, **kwargs: Unpack[SessionStartLimit]):
self.total: int = kwargs['total']
self.remaining: int = kwargs['remaining']
self.reset_after: int = kwargs['reset_after']
self.max_concurrency: int = kwargs['max_concurrency']
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
@ -322,6 +355,11 @@ class AutoShardedClient(Client):
------------
shard_ids: Optional[List[:class:`int`]]
An optional list of shard_ids to launch the shards with.
shard_connect_timeout: Optional[:class:`float`]
The maximum number of seconds to wait before timing out when launching a shard.
Defaults to 180 seconds.
.. versionadded:: 2.4
"""
if TYPE_CHECKING:
@ -330,6 +368,8 @@ class AutoShardedClient(Client):
def __init__(self, *args: Any, intents: Intents, **kwargs: Any) -> None:
kwargs.pop('shard_id', None)
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
self.shard_connect_timeout: Optional[float] = kwargs.pop('shard_connect_timeout', 180.0)
super().__init__(*args, intents=intents, **kwargs)
if self.shard_ids is not None:
@ -404,10 +444,37 @@ class AutoShardedClient(Client):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
async def fetch_session_start_limits(self) -> SessionStartLimits:
"""|coro|
Get the session start limits.
This is not typically needed, and will be handled for you by default.
At the point where you are launching multiple instances
with manual shard ranges and are considered required to use large bot
sharding by Discord, this function when used along IPC and a
before_identity_hook can speed up session start.
.. versionadded:: 2.5
Returns
-------
:class:`SessionStartLimits`
A class containing the session start limits
Raises
------
GatewayNotFound
The gateway was unreachable
"""
_, _, limits = await self.http.get_bot_gateway()
return SessionStartLimits(**limits)
async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
ws = await asyncio.wait_for(coro, timeout=self.shard_connect_timeout)
except Exception:
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0)
@ -423,7 +490,7 @@ class AutoShardedClient(Client):
if self.shard_count is None:
self.shard_count: int
self.shard_count, gateway_url = await self.http.get_bot_gateway()
self.shard_count, gateway_url, _session_start_limit = await self.http.get_bot_gateway()
gateway = yarl.URL(gateway_url)
else:
gateway = DiscordWebSocket.DEFAULT_GATEWAY
@ -450,10 +517,10 @@ class AutoShardedClient(Client):
if item.type == EventType.close:
await self.close()
if isinstance(item.error, ConnectionClosed):
if item.error.code != 1000:
raise item.error
if item.error.code == 4014:
raise PrivilegedIntentsRequired(item.shard.id) from None
if item.error.code != 1000:
raise item.error
return
elif item.type in (EventType.identify, EventType.resume):
await item.shard.reidentify(item.error)
@ -470,18 +537,21 @@ class AutoShardedClient(Client):
Closes the connection to Discord.
"""
if self.is_closed():
return
if self._closing_task:
return await self._closing_task
async def _close():
await self._connection.close()
self._closed = True
await self._connection.close()
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
self._closing_task = asyncio.create_task(_close())
await self._closing_task
async def change_presence(
self,

359
discord/sku.py

@ -0,0 +1,359 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import AsyncIterator, Optional, TYPE_CHECKING
from datetime import datetime
from . import utils
from .enums import try_enum, SKUType, EntitlementType
from .flags import SKUFlags
from .object import Object
from .subscription import Subscription
if TYPE_CHECKING:
from .abc import SnowflakeTime, Snowflake
from .guild import Guild
from .state import ConnectionState
from .types.sku import (
SKU as SKUPayload,
Entitlement as EntitlementPayload,
)
from .user import User
__all__ = (
'SKU',
'Entitlement',
)
class SKU:
"""Represents a premium offering as a stock-keeping unit (SKU).
.. versionadded:: 2.4
Attributes
-----------
id: :class:`int`
The SKU's ID.
type: :class:`SKUType`
The type of the SKU.
application_id: :class:`int`
The ID of the application that the SKU belongs to.
name: :class:`str`
The consumer-facing name of the premium offering.
slug: :class:`str`
A system-generated URL slug based on the SKU name.
"""
__slots__ = (
'_state',
'id',
'type',
'application_id',
'name',
'slug',
'_flags',
)
def __init__(self, *, state: ConnectionState, data: SKUPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.type: SKUType = try_enum(SKUType, data['type'])
self.application_id: int = int(data['application_id'])
self.name: str = data['name']
self.slug: str = data['slug']
self._flags: int = data['flags']
def __repr__(self) -> str:
return f'<SKU id={self.id} name={self.name!r} slug={self.slug!r}>'
@property
def flags(self) -> SKUFlags:
""":class:`SKUFlags`: Returns the flags of the SKU."""
return SKUFlags._from_value(self._flags)
@property
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the sku's creation time in UTC."""
return utils.snowflake_time(self.id)
async def fetch_subscription(self, subscription_id: int, /) -> Subscription:
"""|coro|
Retrieves a :class:`.Subscription` with the specified ID.
.. versionadded:: 2.5
Parameters
-----------
subscription_id: :class:`int`
The subscription's ID to fetch from.
Raises
-------
NotFound
An subscription with this ID does not exist.
HTTPException
Fetching the subscription failed.
Returns
--------
:class:`.Subscription`
The subscription you requested.
"""
data = await self._state.http.get_sku_subscription(self.id, subscription_id)
return Subscription(data=data, state=self._state)
async def subscriptions(
self,
*,
limit: Optional[int] = 50,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
user: Snowflake,
) -> AsyncIterator[Subscription]:
"""Retrieves an :term:`asynchronous iterator` of the :class:`.Subscription` that SKU has.
.. versionadded:: 2.5
Examples
---------
Usage ::
async for subscription in sku.subscriptions(limit=100, user=user):
print(subscription.user_id, subscription.current_period_end)
Flattening into a list ::
subscriptions = [subscription async for subscription in sku.subscriptions(limit=100, user=user)]
# subscriptions is now a list of Subscription...
All parameters are optional.
Parameters
-----------
limit: Optional[:class:`int`]
The number of subscriptions to retrieve. If ``None``, it retrieves every subscription for this SKU.
Note, however, that this would make it a slow operation. Defaults to ``100``.
before: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]]
Retrieve subscriptions before this date or entitlement.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
after: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]]
Retrieve subscriptions after this date or entitlement.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
user: :class:`~discord.abc.Snowflake`
The user to filter by.
Raises
-------
HTTPException
Fetching the subscriptions failed.
TypeError
Both ``after`` and ``before`` were provided, as Discord does not
support this type of pagination.
Yields
--------
:class:`.Subscription`
The subscription with the SKU.
"""
if before is not None and after is not None:
raise TypeError('subscriptions pagination does not support both before and after')
# This endpoint paginates in ascending order.
endpoint = self._state.http.list_sku_subscriptions
async def _before_strategy(retrieve: int, before: Optional[Snowflake], limit: Optional[int]):
before_id = before.id if before else None
data = await endpoint(self.id, before=before_id, limit=retrieve, user_id=user.id)
if data:
if limit is not None:
limit -= len(data)
before = Object(id=int(data[0]['id']))
return data, before, limit
async def _after_strategy(retrieve: int, after: Optional[Snowflake], limit: Optional[int]):
after_id = after.id if after else None
data = await endpoint(
self.id,
after=after_id,
limit=retrieve,
user_id=user.id,
)
if data:
if limit is not None:
limit -= len(data)
after = Object(id=int(data[-1]['id']))
return data, after, limit
if isinstance(before, datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if before:
strategy, state = _before_strategy, before
else:
strategy, state = _after_strategy, after
while True:
retrieve = 100 if limit is None else min(limit, 100)
if retrieve < 1:
return
data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 100:
limit = 0
for e in data:
yield Subscription(data=e, state=self._state)
class Entitlement:
"""Represents an entitlement from user or guild which has been granted access to a premium offering.
.. versionadded:: 2.4
Attributes
-----------
id: :class:`int`
The entitlement's ID.
sku_id: :class:`int`
The ID of the SKU that the entitlement belongs to.
application_id: :class:`int`
The ID of the application that the entitlement belongs to.
user_id: Optional[:class:`int`]
The ID of the user that is granted access to the entitlement.
type: :class:`EntitlementType`
The type of the entitlement.
deleted: :class:`bool`
Whether the entitlement has been deleted.
starts_at: Optional[:class:`datetime.datetime`]
A UTC start date which the entitlement is valid. Not present when using test entitlements.
ends_at: Optional[:class:`datetime.datetime`]
A UTC date which entitlement is no longer valid. Not present when using test entitlements.
guild_id: Optional[:class:`int`]
The ID of the guild that is granted access to the entitlement
consumed: :class:`bool`
For consumable items, whether the entitlement has been consumed.
"""
__slots__ = (
'_state',
'id',
'sku_id',
'application_id',
'user_id',
'type',
'deleted',
'starts_at',
'ends_at',
'guild_id',
'consumed',
)
def __init__(self, state: ConnectionState, data: EntitlementPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.sku_id: int = int(data['sku_id'])
self.application_id: int = int(data['application_id'])
self.user_id: Optional[int] = utils._get_as_snowflake(data, 'user_id')
self.type: EntitlementType = try_enum(EntitlementType, data['type'])
self.deleted: bool = data['deleted']
self.starts_at: Optional[datetime] = utils.parse_time(data.get('starts_at', None))
self.ends_at: Optional[datetime] = utils.parse_time(data.get('ends_at', None))
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.consumed: bool = data.get('consumed', False)
def __repr__(self) -> str:
return f'<Entitlement id={self.id} type={self.type!r} user_id={self.user_id}>'
@property
def user(self) -> Optional[User]:
"""Optional[:class:`User`]: The user that is granted access to the entitlement."""
if self.user_id is None:
return None
return self._state.get_user(self.user_id)
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that is granted access to the entitlement."""
return self._state._get_guild(self.guild_id)
@property
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the entitlement's creation time in UTC."""
return utils.snowflake_time(self.id)
def is_expired(self) -> bool:
""":class:`bool`: Returns ``True`` if the entitlement is expired. Will be always False for test entitlements."""
if self.ends_at is None:
return False
return utils.utcnow() >= self.ends_at
async def consume(self) -> None:
"""|coro|
Marks a one-time purchase entitlement as consumed.
Raises
-------
NotFound
The entitlement could not be found.
HTTPException
Consuming the entitlement failed.
"""
await self._state.http.consume_entitlement(self.application_id, self.id)
async def delete(self) -> None:
"""|coro|
Deletes the entitlement.
Raises
-------
NotFound
The entitlement could not be found.
HTTPException
Deleting the entitlement failed.
"""
await self._state.http.delete_entitlement(self.application_id, self.id)

325
discord/soundboard.py

@ -0,0 +1,325 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from . import utils
from .mixins import Hashable
from .partial_emoji import PartialEmoji, _EmojiTag
from .user import User
from .utils import MISSING
from .asset import Asset, AssetMixin
if TYPE_CHECKING:
import datetime
from typing import Dict, Any
from .types.soundboard import (
BaseSoundboardSound as BaseSoundboardSoundPayload,
SoundboardDefaultSound as SoundboardDefaultSoundPayload,
SoundboardSound as SoundboardSoundPayload,
)
from .state import ConnectionState
from .guild import Guild
from .message import EmojiInputType
__all__ = ('BaseSoundboardSound', 'SoundboardDefaultSound', 'SoundboardSound')
class BaseSoundboardSound(Hashable, AssetMixin):
"""Represents a generic Discord soundboard sound.
.. versionadded:: 2.5
.. container:: operations
.. describe:: x == y
Checks if two sounds are equal.
.. describe:: x != y
Checks if two sounds are not equal.
.. describe:: hash(x)
Returns the sound's hash.
Attributes
------------
id: :class:`int`
The ID of the sound.
volume: :class:`float`
The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%).
"""
__slots__ = ('_state', 'id', 'volume')
def __init__(self, *, state: ConnectionState, data: BaseSoundboardSoundPayload):
self._state: ConnectionState = state
self.id: int = int(data['sound_id'])
self._update(data)
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.id == other.id
return NotImplemented
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def _update(self, data: BaseSoundboardSoundPayload):
self.volume: float = data['volume']
@property
def url(self) -> str:
""":class:`str`: Returns the URL of the sound."""
return f'{Asset.BASE}/soundboard-sounds/{self.id}'
class SoundboardDefaultSound(BaseSoundboardSound):
"""Represents a Discord soundboard default sound.
.. versionadded:: 2.5
.. container:: operations
.. describe:: x == y
Checks if two sounds are equal.
.. describe:: x != y
Checks if two sounds are not equal.
.. describe:: hash(x)
Returns the sound's hash.
Attributes
------------
id: :class:`int`
The ID of the sound.
volume: :class:`float`
The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%).
name: :class:`str`
The name of the sound.
emoji: :class:`PartialEmoji`
The emoji of the sound.
"""
__slots__ = ('name', 'emoji')
def __init__(self, *, state: ConnectionState, data: SoundboardDefaultSoundPayload):
self.name: str = data['name']
self.emoji: PartialEmoji = PartialEmoji(name=data['emoji_name'])
super().__init__(state=state, data=data)
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('volume', self.volume),
('emoji', self.emoji),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f"<{self.__class__.__name__} {inner}>"
class SoundboardSound(BaseSoundboardSound):
"""Represents a Discord soundboard sound.
.. versionadded:: 2.5
.. container:: operations
.. describe:: x == y
Checks if two sounds are equal.
.. describe:: x != y
Checks if two sounds are not equal.
.. describe:: hash(x)
Returns the sound's hash.
Attributes
------------
id: :class:`int`
The ID of the sound.
volume: :class:`float`
The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%).
name: :class:`str`
The name of the sound.
emoji: Optional[:class:`PartialEmoji`]
The emoji of the sound. ``None`` if no emoji is set.
guild: :class:`Guild`
The guild in which the sound is uploaded.
available: :class:`bool`
Whether this sound is available for use.
"""
__slots__ = ('_state', 'name', 'emoji', '_user', 'available', '_user_id', 'guild')
def __init__(self, *, guild: Guild, state: ConnectionState, data: SoundboardSoundPayload):
super().__init__(state=state, data=data)
self.guild = guild
self._user_id = utils._get_as_snowflake(data, 'user_id')
self._user = data.get('user')
self._update(data)
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('volume', self.volume),
('emoji', self.emoji),
('user', self.user),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f"<{self.__class__.__name__} {inner}>"
def _update(self, data: SoundboardSoundPayload):
super()._update(data)
self.name: str = data['name']
self.emoji: Optional[PartialEmoji] = None
emoji_id = utils._get_as_snowflake(data, 'emoji_id')
emoji_name = data['emoji_name']
if emoji_id is not None or emoji_name is not None:
self.emoji = PartialEmoji(id=emoji_id, name=emoji_name) # type: ignore # emoji_name cannot be None here
self.available: bool = data['available']
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the snowflake's creation time in UTC."""
return utils.snowflake_time(self.id)
@property
def user(self) -> Optional[User]:
"""Optional[:class:`User`]: The user who uploaded the sound."""
if self._user is None:
if self._user_id is None:
return None
return self._state.get_user(self._user_id)
return User(state=self._state, data=self._user)
async def edit(
self,
*,
name: str = MISSING,
volume: Optional[float] = MISSING,
emoji: Optional[EmojiInputType] = MISSING,
reason: Optional[str] = None,
):
"""|coro|
Edits the soundboard sound.
You must have :attr:`~Permissions.manage_expressions` to edit the sound.
If the sound was created by the client, you must have either :attr:`~Permissions.manage_expressions`
or :attr:`~Permissions.create_expressions`.
Parameters
----------
name: :class:`str`
The new name of the sound. Must be between 2 and 32 characters.
volume: Optional[:class:`float`]
The new volume of the sound. Must be between 0 and 1.
emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`, :class:`str`]]
The new emoji of the sound.
reason: Optional[:class:`str`]
The reason for editing this sound. Shows up on the audit log.
Raises
-------
Forbidden
You do not have permissions to edit the soundboard sound.
HTTPException
Editing the soundboard sound failed.
Returns
-------
:class:`SoundboardSound`
The newly updated soundboard sound.
"""
payload: Dict[str, Any] = {}
if name is not MISSING:
payload['name'] = name
if volume is not MISSING:
payload['volume'] = volume
if emoji is not MISSING:
if emoji is None:
payload['emoji_id'] = None
payload['emoji_name'] = None
else:
if isinstance(emoji, _EmojiTag):
partial_emoji = emoji._to_partial()
elif isinstance(emoji, str):
partial_emoji = PartialEmoji.from_str(emoji)
else:
partial_emoji = None
if partial_emoji is not None:
if partial_emoji.id is None:
payload['emoji_name'] = partial_emoji.name
else:
payload['emoji_id'] = partial_emoji.id
data = await self._state.http.edit_soundboard_sound(self.guild.id, self.id, reason=reason, **payload)
return SoundboardSound(guild=self.guild, state=self._state, data=data)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the soundboard sound.
You must have :attr:`~Permissions.manage_expressions` to delete the sound.
If the sound was created by the client, you must have either :attr:`~Permissions.manage_expressions`
or :attr:`~Permissions.create_expressions`.
Parameters
-----------
reason: Optional[:class:`str`]
The reason for deleting this sound. Shows up on the audit log.
Raises
-------
Forbidden
You do not have permissions to delete the soundboard sound.
HTTPException
Deleting the soundboard sound failed.
"""
await self._state.http.delete_soundboard_sound(self.guild.id, self.id, reason=reason)

298
discord/state.py

@ -32,6 +32,7 @@ from typing import (
Dict,
Optional,
TYPE_CHECKING,
Type,
Union,
Callable,
Any,
@ -52,6 +53,7 @@ import os
from .guild import Guild
from .activity import BaseActivity
from .sku import Entitlement
from .user import User, ClientUser
from .emoji import Emoji
from .mentions import AllowedMentions
@ -60,6 +62,7 @@ from .message import Message
from .channel import *
from .channel import _channel_factory
from .raw_models import *
from .presences import RawPresenceUpdateEvent
from .member import Member
from .role import Role
from .enums import ChannelType, try_enum, Status
@ -76,6 +79,9 @@ from .sticker import GuildSticker
from .automod import AutoModRule, AutoModAction
from .audit_logs import AuditLogEntry
from ._types import ClientT
from .soundboard import SoundboardSound
from .subscription import Subscription
if TYPE_CHECKING:
from .abc import PrivateChannel
@ -84,7 +90,10 @@ if TYPE_CHECKING:
from .http import HTTPClient
from .voice_client import VoiceProtocol
from .gateway import DiscordWebSocket
from .ui.item import Item
from .ui.dynamic import DynamicItem
from .app_commands import CommandTree, Translator
from .poll import Poll
from .types.automod import AutoModerationRule, AutoModerationActionExecution
from .types.snowflake import Snowflake
@ -106,12 +115,14 @@ class ChunkRequest:
def __init__(
self,
guild_id: int,
shard_id: int,
loop: asyncio.AbstractEventLoop,
resolver: Callable[[int], Any],
*,
cache: bool = True,
) -> None:
self.guild_id: int = guild_id
self.shard_id: int = shard_id
self.resolver: Callable[[int], Any] = resolver
self.loop: asyncio.AbstractEventLoop = loop
self.cache: bool = cache
@ -251,6 +262,10 @@ class ConnectionState(Generic[ClientT]):
if not intents.members or cache_flags._empty:
self.store_user = self.store_user_no_intents
self.raw_presence_flag: bool = options.get('enable_raw_presences', utils.MISSING)
if self.raw_presence_flag is utils.MISSING:
self.raw_presence_flag = not intents.members and intents.presences
self.parsers: Dict[str, Callable[[Any], None]]
self.parsers = parsers = {}
for attr, func in inspect.getmembers(self):
@ -259,6 +274,13 @@ class ConnectionState(Generic[ClientT]):
self.clear()
# For some reason Discord still sends emoji/sticker data in payloads
# This makes it hard to actually swap out the appropriate store methods
# So this is checked instead, it's a small penalty to pay
@property
def cache_guild_expressions(self) -> bool:
return self._intents.emojis_and_stickers
async def close(self) -> None:
for voice in self.voice_clients:
try:
@ -304,6 +326,16 @@ class ConnectionState(Generic[ClientT]):
for key in removed:
del self._chunk_requests[key]
def clear_chunk_requests(self, shard_id: int | None) -> None:
removed = []
for key, request in self._chunk_requests.items():
if shard_id is None or request.shard_id == shard_id:
request.done()
removed.append(key)
for key in removed:
del self._chunk_requests[key]
def call_handlers(self, key: str, *args: Any, **kwargs: Any) -> None:
try:
func = self.handlers[key]
@ -349,18 +381,18 @@ class ConnectionState(Generic[ClientT]):
for vc in self.voice_clients:
vc.main_ws = ws # type: ignore # Silencing the unknown attribute (ok at runtime).
def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
def store_user(self, data: Union[UserPayload, PartialUserPayload], *, cache: bool = True) -> User:
# this way is 300% faster than `dict.setdefault`.
user_id = int(data['id'])
try:
return self._users[user_id]
except KeyError:
user = User(state=self, data=data)
if user.discriminator != '0000':
if cache:
self._users[user_id] = user
return user
def store_user_no_intents(self, data: Union[UserPayload, PartialUserPayload]) -> User:
def store_user_no_intents(self, data: Union[UserPayload, PartialUserPayload], *, cache: bool = True) -> User:
return User(state=self, data=data)
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
@ -388,6 +420,12 @@ class ConnectionState(Generic[ClientT]):
def prevent_view_updates_for(self, message_id: int) -> Optional[View]:
return self._view_store.remove_message_tracking(message_id)
def store_dynamic_items(self, *items: Type[DynamicItem[Item[Any]]]) -> None:
self._view_store.add_dynamic_items(*items)
def remove_dynamic_items(self, *items: Type[DynamicItem[Item[Any]]]) -> None:
self._view_store.remove_dynamic_items(*items)
@property
def persistent_views(self) -> Sequence[View]:
return self._view_store.persistent_views
@ -400,8 +438,8 @@ class ConnectionState(Generic[ClientT]):
# the keys of self._guilds are ints
return self._guilds.get(guild_id) # type: ignore
def _get_or_create_unavailable_guild(self, guild_id: int) -> Guild:
return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id)
def _get_or_create_unavailable_guild(self, guild_id: int, *, data: Optional[Dict[str, Any]] = None) -> Guild:
return self._guilds.get(guild_id) or Guild._create_unavailable(state=self, guild_id=guild_id, data=data)
def _add_guild(self, guild: Guild) -> None:
self._guilds[guild.id] = guild
@ -425,6 +463,14 @@ class ConnectionState(Generic[ClientT]):
def stickers(self) -> Sequence[GuildSticker]:
return utils.SequenceProxy(self._stickers.values())
@property
def soundboard_sounds(self) -> List[SoundboardSound]:
all_sounds = []
for guild in self.guilds:
all_sounds.extend(guild.soundboard_sounds)
return all_sounds
def get_emoji(self, emoji_id: Optional[int]) -> Optional[Emoji]:
# the keys of self._emojis are ints
return self._emojis.get(emoji_id) # type: ignore
@ -494,7 +540,7 @@ class ConnectionState(Generic[ClientT]):
) -> Tuple[Union[Channel, Thread], Optional[Guild]]:
channel_id = int(data['channel_id'])
try:
guild_id = guild_id or int(data['guild_id'])
guild_id = guild_id or int(data['guild_id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
guild = self._get_guild(guild_id)
except KeyError:
channel = DMChannel._from_message(self, channel_id)
@ -504,6 +550,34 @@ class ConnectionState(Generic[ClientT]):
return channel or PartialMessageable(state=self, guild_id=guild_id, id=channel_id), guild
def _update_poll_counts(self, message: Message, answer_id: int, added: bool, self_voted: bool = False) -> Optional[Poll]:
poll = message.poll
if not poll:
return
poll._handle_vote(answer_id, added, self_voted)
return poll
def _update_poll_results(self, from_: Message, to: Union[Message, int]) -> None:
if isinstance(to, Message):
cached = self._get_message(to.id)
elif isinstance(to, int):
cached = self._get_message(to)
if cached is None:
return
to = cached
else:
return
if to.poll is None:
return
to.poll._update_results_from_message(from_)
if cached is not None and cached.poll:
cached.poll._update_results_from_message(from_)
async def chunker(
self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None
) -> None:
@ -518,7 +592,7 @@ class ConnectionState(Generic[ClientT]):
if ws is None:
raise RuntimeError('Somehow do not have a websocket for this guild_id')
request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
request = ChunkRequest(guild.id, guild.shard_id, self.loop, self._get_guild, cache=cache)
self._chunk_requests[request.nonce] = request
try:
@ -585,6 +659,7 @@ class ConnectionState(Generic[ClientT]):
self._ready_state: asyncio.Queue[Guild] = asyncio.Queue()
self.clear(views=False)
self.clear_chunk_requests(None)
self.user = user = ClientUser(state=self, data=data['user'])
self._users[user.id] = user # type: ignore
@ -614,7 +689,7 @@ class ConnectionState(Generic[ClientT]):
if self._messages is not None:
self._messages.append(message)
# we ensure that the channel is either a TextChannel, VoiceChannel, or Thread
if channel and channel.__class__ in (TextChannel, VoiceChannel, Thread):
if channel and channel.__class__ in (TextChannel, VoiceChannel, Thread, StageChannel):
channel.last_message_id = message.id # type: ignore
def parse_message_delete(self, data: gw.MessageDeleteEvent) -> None:
@ -641,23 +716,27 @@ class ConnectionState(Generic[ClientT]):
self._messages.remove(msg) # type: ignore
def parse_message_update(self, data: gw.MessageUpdateEvent) -> None:
raw = RawMessageUpdateEvent(data)
message = self._get_message(raw.message_id)
if message is not None:
older_message = copy.copy(message)
channel, _ = self._get_guild_channel(data)
# channel would be the correct type here
updated_message = Message(channel=channel, data=data, state=self) # type: ignore
raw = RawMessageUpdateEvent(data=data, message=updated_message)
cached_message = self._get_message(updated_message.id)
if cached_message is not None:
older_message = copy.copy(cached_message)
raw.cached_message = older_message
self.dispatch('raw_message_edit', raw)
message._update(data)
cached_message._update(data)
# Coerce the `after` parameter to take the new updated Member
# ref: #5999
older_message.author = message.author
self.dispatch('message_edit', older_message, message)
older_message.author = updated_message.author
self.dispatch('message_edit', older_message, updated_message)
else:
self.dispatch('raw_message_edit', raw)
if 'components' in data:
try:
entity_id = int(data['interaction']['id'])
entity_id = int(data['interaction']['id']) # pyright: ignore[reportTypedDictNotRequiredAccess]
except (KeyError, ValueError):
entity_id = raw.message_id
@ -753,22 +832,24 @@ class ConnectionState(Generic[ClientT]):
self.dispatch('interaction', interaction)
def parse_presence_update(self, data: gw.PresenceUpdateEvent) -> None:
guild_id = utils._get_as_snowflake(data, 'guild_id')
# guild_id won't be None here
guild = self._get_guild(guild_id)
if guild is None:
_log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
raw = RawPresenceUpdateEvent(data=data, state=self)
if self.raw_presence_flag:
self.dispatch('raw_presence_update', raw)
if raw.guild is None:
_log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', raw.guild_id)
return
user = data['user']
member_id = int(user['id'])
member = guild.get_member(member_id)
member = raw.guild.get_member(raw.user_id)
if member is None:
_log.debug('PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding', member_id)
_log.debug('PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding', raw.user_id)
return
old_member = Member._copy(member)
user_update = member._presence_update(data=data, user=user)
user_update = member._presence_update(raw=raw, user=data['user'])
if user_update:
self.dispatch('user_update', user_update[0], user_update[1])
@ -801,6 +882,12 @@ class ConnectionState(Generic[ClientT]):
guild._scheduled_events.pop(s.id)
self.dispatch('scheduled_event_delete', s)
threads = guild._remove_threads_by_channel(channel_id)
for thread in threads:
self.dispatch('thread_delete', thread)
self.dispatch('raw_thread_delete', RawThreadDeleteEvent._from_thread(thread))
def parse_channel_update(self, data: gw.ChannelUpdateEvent) -> None:
channel_type = try_enum(ChannelType, data.get('type'))
channel_id = int(data['id'])
@ -848,7 +935,7 @@ class ConnectionState(Generic[ClientT]):
def parse_channel_pins_update(self, data: gw.ChannelPinsUpdateEvent) -> None:
channel_id = int(data['channel_id'])
try:
guild = self._get_guild(int(data['guild_id']))
guild = self._get_guild(int(data['guild_id'])) # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
guild = None
channel = self._get_private_channel(channel_id)
@ -930,7 +1017,7 @@ class ConnectionState(Generic[ClientT]):
return
try:
channel_ids = {int(i) for i in data['channel_ids']}
channel_ids = {int(i) for i in data['channel_ids']} # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
# If not provided, then the entire guild is being synced
# So all previous thread data should be overwritten
@ -1108,6 +1195,7 @@ class ConnectionState(Generic[ClientT]):
integrations={},
app_commands={},
automod_rules={},
webhooks={},
data=data,
guild=guild,
)
@ -1186,7 +1274,9 @@ class ConnectionState(Generic[ClientT]):
cache = cache or self.member_cache_flags.joined
request = self._chunk_requests.get(guild.id)
if request is None:
self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
self._chunk_requests[guild.id] = request = ChunkRequest(
guild.id, guild.shard_id, self.loop, self._get_guild, cache=cache
)
await self.chunker(guild.id, nonce=request.nonce)
if wait:
@ -1347,8 +1437,10 @@ class ConnectionState(Generic[ClientT]):
user = presence['user']
member_id = user['id']
member = member_dict.get(member_id)
if member is not None:
member._presence_update(presence, user)
raw_presence = RawPresenceUpdateEvent(data=presence, state=self)
member._presence_update(raw_presence, user)
complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count')
self.process_chunk_requests(guild_id, data.get('nonce'), members, complete)
@ -1461,12 +1553,8 @@ class ConnectionState(Generic[ClientT]):
def parse_guild_scheduled_event_delete(self, data: gw.GuildScheduledEventDeleteEvent) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
try:
scheduled_event = guild._scheduled_events.pop(int(data['id']))
except KeyError:
pass
else:
self.dispatch('scheduled_event_delete', scheduled_event)
scheduled_event = guild._scheduled_events.pop(int(data['id']), ScheduledEvent(state=self, data=data))
self.dispatch('scheduled_event_delete', scheduled_event)
else:
_log.debug('SCHEDULED_EVENT_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
@ -1508,6 +1596,63 @@ class ConnectionState(Generic[ClientT]):
else:
_log.debug('SCHEDULED_EVENT_USER_REMOVE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
def parse_guild_soundboard_sound_create(self, data: gw.GuildSoundBoardSoundCreateEvent) -> None:
guild_id = int(data['guild_id']) # type: ignore # can't be None here
guild = self._get_guild(guild_id)
if guild is not None:
sound = SoundboardSound(guild=guild, state=self, data=data)
guild._add_soundboard_sound(sound)
self.dispatch('soundboard_sound_create', sound)
else:
_log.debug('GUILD_SOUNDBOARD_SOUND_CREATE referencing unknown guild ID: %s. Discarding.', guild_id)
def _update_and_dispatch_sound_update(self, sound: SoundboardSound, data: gw.GuildSoundBoardSoundUpdateEvent):
old_sound = copy.copy(sound)
sound._update(data)
self.dispatch('soundboard_sound_update', old_sound, sound)
def parse_guild_soundboard_sound_update(self, data: gw.GuildSoundBoardSoundUpdateEvent) -> None:
guild_id = int(data['guild_id']) # type: ignore # can't be None here
guild = self._get_guild(guild_id)
if guild is not None:
sound_id = int(data['sound_id'])
sound = guild.get_soundboard_sound(sound_id)
if sound is not None:
self._update_and_dispatch_sound_update(sound, data)
else:
_log.warning('GUILD_SOUNDBOARD_SOUND_UPDATE referencing unknown sound ID: %s. Discarding.', sound_id)
else:
_log.debug('GUILD_SOUNDBOARD_SOUND_UPDATE referencing unknown guild ID: %s. Discarding.', guild_id)
def parse_guild_soundboard_sound_delete(self, data: gw.GuildSoundBoardSoundDeleteEvent) -> None:
guild_id = int(data['guild_id'])
guild = self._get_guild(guild_id)
if guild is not None:
sound_id = int(data['sound_id'])
sound = guild.get_soundboard_sound(sound_id)
if sound is not None:
guild._remove_soundboard_sound(sound)
self.dispatch('soundboard_sound_delete', sound)
else:
_log.warning('GUILD_SOUNDBOARD_SOUND_DELETE referencing unknown sound ID: %s. Discarding.', sound_id)
else:
_log.debug('GUILD_SOUNDBOARD_SOUND_DELETE referencing unknown guild ID: %s. Discarding.', guild_id)
def parse_guild_soundboard_sounds_update(self, data: gw.GuildSoundBoardSoundsUpdateEvent) -> None:
guild_id = int(data['guild_id'])
guild = self._get_guild(guild_id)
if guild is None:
_log.debug('GUILD_SOUNDBOARD_SOUNDS_UPDATE referencing unknown guild ID: %s. Discarding.', guild_id)
return
for raw_sound in data['soundboard_sounds']:
sound_id = int(raw_sound['sound_id'])
sound = guild.get_soundboard_sound(sound_id)
if sound is not None:
self._update_and_dispatch_sound_update(sound, raw_sound)
else:
_log.warning('GUILD_SOUNDBOARD_SOUNDS_UPDATE referencing unknown sound ID: %s. Discarding.', sound_id)
def parse_application_command_permissions_update(self, data: GuildApplicationCommandPermissionsPayload):
raw = RawAppCommandPermissionsUpdateEvent(data=data, state=self)
self.dispatch('raw_app_command_permissions_update', raw)
@ -1538,6 +1683,14 @@ class ConnectionState(Generic[ClientT]):
else:
_log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id'])
def parse_voice_channel_effect_send(self, data: gw.VoiceChannelEffectSendEvent):
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
effect = VoiceChannelEffect(state=self, data=data, guild=guild)
self.dispatch('voice_channel_effect', effect)
else:
_log.debug('VOICE_CHANNEL_EFFECT_SEND referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
def parse_voice_server_update(self, data: gw.VoiceServerUpdateEvent) -> None:
key_id = int(data['guild_id'])
@ -1553,7 +1706,8 @@ class ConnectionState(Generic[ClientT]):
if channel is not None:
if isinstance(channel, DMChannel):
channel.recipient = raw.user
if raw.user is not None and raw.user not in channel.recipients:
channel.recipients.append(raw.user)
elif guild is not None:
raw.user = guild.get_member(raw.user_id)
@ -1567,6 +1721,66 @@ class ConnectionState(Generic[ClientT]):
self.dispatch('raw_typing', raw)
def parse_entitlement_create(self, data: gw.EntitlementCreateEvent) -> None:
entitlement = Entitlement(data=data, state=self)
self.dispatch('entitlement_create', entitlement)
def parse_entitlement_update(self, data: gw.EntitlementUpdateEvent) -> None:
entitlement = Entitlement(data=data, state=self)
self.dispatch('entitlement_update', entitlement)
def parse_entitlement_delete(self, data: gw.EntitlementDeleteEvent) -> None:
entitlement = Entitlement(data=data, state=self)
self.dispatch('entitlement_delete', entitlement)
def parse_message_poll_vote_add(self, data: gw.PollVoteActionEvent) -> None:
raw = RawPollVoteActionEvent(data)
self.dispatch('raw_poll_vote_add', raw)
message = self._get_message(raw.message_id)
guild = self._get_guild(raw.guild_id)
if guild:
user = guild.get_member(raw.user_id)
else:
user = self.get_user(raw.user_id)
if message and user:
poll = self._update_poll_counts(message, raw.answer_id, True, raw.user_id == self.self_id)
if poll:
self.dispatch('poll_vote_add', user, poll.get_answer(raw.answer_id))
def parse_message_poll_vote_remove(self, data: gw.PollVoteActionEvent) -> None:
raw = RawPollVoteActionEvent(data)
self.dispatch('raw_poll_vote_remove', raw)
message = self._get_message(raw.message_id)
guild = self._get_guild(raw.guild_id)
if guild:
user = guild.get_member(raw.user_id)
else:
user = self.get_user(raw.user_id)
if message and user:
poll = self._update_poll_counts(message, raw.answer_id, False, raw.user_id == self.self_id)
if poll:
self.dispatch('poll_vote_remove', user, poll.get_answer(raw.answer_id))
def parse_subscription_create(self, data: gw.SubscriptionCreateEvent) -> None:
subscription = Subscription(data=data, state=self)
self.dispatch('subscription_create', subscription)
def parse_subscription_update(self, data: gw.SubscriptionUpdateEvent) -> None:
subscription = Subscription(data=data, state=self)
self.dispatch('subscription_update', subscription)
def parse_subscription_delete(self, data: gw.SubscriptionDeleteEvent) -> None:
subscription = Subscription(data=data, state=self)
self.dispatch('subscription_delete', subscription)
def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]:
if isinstance(channel, (TextChannel, Thread, VoiceChannel)):
return channel.guild.get_member(user_id)
@ -1611,6 +1825,15 @@ class ConnectionState(Generic[ClientT]):
def create_message(self, *, channel: MessageableChannel, data: MessagePayload) -> Message:
return Message(state=self, channel=channel, data=data)
def get_soundboard_sound(self, id: Optional[int]) -> Optional[SoundboardSound]:
if id is None:
return
for guild in self.guilds:
sound = guild._resolve_soundboard_sound(id)
if sound is not None:
return sound
class AutoShardedConnectionState(ConnectionState[ClientT]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
@ -1721,6 +1944,7 @@ class AutoShardedConnectionState(ConnectionState[ClientT]):
if shard_id in self._ready_tasks:
self._ready_tasks[shard_id].cancel()
self.clear_chunk_requests(shard_id)
if shard_id not in self._ready_states:
self._ready_states[shard_id] = asyncio.Queue()

32
discord/sticker.py

@ -28,8 +28,7 @@ import unicodedata
from .mixins import Hashable
from .asset import Asset, AssetMixin
from .utils import cached_slot_property, find, snowflake_time, get, MISSING, _get_as_snowflake
from .errors import InvalidData
from .utils import cached_slot_property, snowflake_time, get, MISSING, _get_as_snowflake
from .enums import StickerType, StickerFormatType, try_enum
__all__ = (
@ -51,7 +50,6 @@ if TYPE_CHECKING:
Sticker as StickerPayload,
StandardSticker as StandardStickerPayload,
GuildSticker as GuildStickerPayload,
ListPremiumStickerPacks as ListPremiumStickerPacksPayload,
)
@ -203,7 +201,10 @@ class StickerItem(_StickerTag):
self.name: str = data['name']
self.id: int = int(data['id'])
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
if self.format is StickerFormatType.gif:
self.url: str = f'https://media.discordapp.net/stickers/{self.id}.gif'
else:
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
def __repr__(self) -> str:
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>'
@ -258,8 +259,6 @@ class Sticker(_StickerTag):
The id of the sticker.
description: :class:`str`
The description of the sticker.
pack_id: :class:`int`
The id of the sticker's pack.
format: :class:`StickerFormatType`
The format for the sticker's image.
url: :class:`str`
@ -277,7 +276,10 @@ class Sticker(_StickerTag):
self.name: str = data['name']
self.description: str = data['description']
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
if self.format is StickerFormatType.gif:
self.url: str = f'https://media.discordapp.net/stickers/{self.id}.gif'
else:
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
def __repr__(self) -> str:
return f'<Sticker id={self.id} name={self.name!r}>'
@ -349,9 +351,12 @@ class StandardSticker(Sticker):
Retrieves the sticker pack that this sticker belongs to.
.. versionchanged:: 2.5
Now raises ``NotFound`` instead of ``InvalidData``.
Raises
--------
InvalidData
NotFound
The corresponding sticker pack was not found.
HTTPException
Retrieving the sticker pack failed.
@ -361,13 +366,8 @@ class StandardSticker(Sticker):
:class:`StickerPack`
The retrieved sticker pack.
"""
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
packs = data['sticker_packs']
pack = find(lambda d: int(d['id']) == self.pack_id, packs)
if pack:
return StickerPack(state=self._state, data=pack)
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}')
data = await self._state.http.get_sticker_pack(self.pack_id)
return StickerPack(state=self._state, data=data)
class GuildSticker(Sticker):
@ -414,7 +414,7 @@ class GuildSticker(Sticker):
def _from_data(self, data: GuildStickerPayload) -> None:
super()._from_data(data)
self.available: bool = data['available']
self.available: bool = data.get('available', True)
self.guild_id: int = int(data['guild_id'])
user = data.get('user')
self.user: Optional[User] = self._state.store_user(user) if user else None

107
discord/subscription.py

@ -0,0 +1,107 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import datetime
from typing import List, Optional, TYPE_CHECKING
from . import utils
from .mixins import Hashable
from .enums import try_enum, SubscriptionStatus
if TYPE_CHECKING:
from .state import ConnectionState
from .types.subscription import Subscription as SubscriptionPayload
from .user import User
__all__ = ('Subscription',)
class Subscription(Hashable):
"""Represents a Discord subscription.
.. versionadded:: 2.5
Attributes
-----------
id: :class:`int`
The subscription's ID.
user_id: :class:`int`
The ID of the user that is subscribed.
sku_ids: List[:class:`int`]
The IDs of the SKUs that the user subscribed to.
entitlement_ids: List[:class:`int`]
The IDs of the entitlements granted for this subscription.
current_period_start: :class:`datetime.datetime`
When the current billing period started.
current_period_end: :class:`datetime.datetime`
When the current billing period ends.
status: :class:`SubscriptionStatus`
The status of the subscription.
canceled_at: Optional[:class:`datetime.datetime`]
When the subscription was canceled.
This is only available for subscriptions with a :attr:`status` of :attr:`SubscriptionStatus.inactive`.
renewal_sku_ids: List[:class:`int`]
The IDs of the SKUs that the user is going to be subscribed to when renewing.
"""
__slots__ = (
'_state',
'id',
'user_id',
'sku_ids',
'entitlement_ids',
'current_period_start',
'current_period_end',
'status',
'canceled_at',
'renewal_sku_ids',
)
def __init__(self, *, state: ConnectionState, data: SubscriptionPayload):
self._state = state
self.id: int = int(data['id'])
self.user_id: int = int(data['user_id'])
self.sku_ids: List[int] = list(map(int, data['sku_ids']))
self.entitlement_ids: List[int] = list(map(int, data['entitlement_ids']))
self.current_period_start: datetime.datetime = utils.parse_time(data['current_period_start'])
self.current_period_end: datetime.datetime = utils.parse_time(data['current_period_end'])
self.status: SubscriptionStatus = try_enum(SubscriptionStatus, data['status'])
self.canceled_at: Optional[datetime.datetime] = utils.parse_time(data['canceled_at'])
self.renewal_sku_ids: List[int] = list(map(int, data['renewal_sku_ids'] or []))
def __repr__(self) -> str:
return f'<Subscription id={self.id} user_id={self.user_id} status={self.status!r}>'
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the subscription's creation time in UTC."""
return utils.snowflake_time(self.id)
@property
def user(self) -> Optional[User]:
"""Optional[:class:`User`]: The user that is subscribed."""
return self._state.get_user(self.user_id)

21
discord/team.py

@ -27,7 +27,7 @@ from __future__ import annotations
from . import utils
from .user import BaseUser
from .asset import Asset
from .enums import TeamMembershipState, try_enum
from .enums import TeamMemberRole, TeamMembershipState, try_enum
from typing import TYPE_CHECKING, Optional, List
@ -108,7 +108,7 @@ class TeamMember(BaseUser):
.. describe:: str(x)
Returns the team member's name with discriminator.
Returns the team member's handle (e.g. ``name`` or ``name#discriminator``).
.. versionadded:: 1.3
@ -119,25 +119,34 @@ class TeamMember(BaseUser):
id: :class:`int`
The team member's unique ID.
discriminator: :class:`str`
The team member's discriminator. This is given when the username has conflicts.
The team member's discriminator. This is a legacy concept that is no longer used.
global_name: Optional[:class:`str`]
The team member's global nickname, taking precedence over the username in display.
.. versionadded:: 2.3
bot: :class:`bool`
Specifies if the user is a bot account.
team: :class:`Team`
The team that the member is from.
membership_state: :class:`TeamMembershipState`
The membership state of the member (e.g. invited or accepted)
role: :class:`TeamMemberRole`
The role of the member within the team.
.. versionadded:: 2.4
"""
__slots__ = ('team', 'membership_state', 'permissions')
__slots__ = ('team', 'membership_state', 'permissions', 'role')
def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload) -> None:
self.team: Team = team
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state'])
self.permissions: List[str] = data['permissions']
self.permissions: List[str] = data.get('permissions', [])
self.role: TeamMemberRole = try_enum(TeamMemberRole, data['role'])
super().__init__(state=state, data=data['user'])
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>'
f'global_name={self.global_name!r} membership_state={self.membership_state!r}>'
)

27
discord/template.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Optional, TYPE_CHECKING, List
from .utils import parse_time, _bytes_to_base64_data, MISSING
from .utils import parse_time, _bytes_to_base64_data, MISSING, deprecated
from .guild import Guild
# fmt: off
@ -69,6 +69,10 @@ class _PartialTemplateState:
def member_cache_flags(self):
return self.__state.member_cache_flags
@property
def cache_guild_expressions(self):
return False
def store_emoji(self, guild, packet) -> None:
return None
@ -146,18 +150,11 @@ class Template:
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at'))
guild_id = int(data['source_guild_id'])
guild: Optional[Guild] = self._state._get_guild(guild_id)
self.source_guild: Guild
if guild is None:
source_serialised = data['serialized_source_guild']
source_serialised['id'] = guild_id
state = _PartialTemplateState(state=self._state)
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
else:
self.source_guild = guild
source_serialised = data['serialized_source_guild']
source_serialised['id'] = int(data['source_guild_id'])
state = _PartialTemplateState(state=self._state)
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
self.is_dirty: Optional[bool] = data.get('is_dirty', None)
@ -167,6 +164,7 @@ class Template:
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
)
@deprecated()
async def create_guild(self, name: str, icon: bytes = MISSING) -> Guild:
"""|coro|
@ -181,6 +179,9 @@ class Template:
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
.. deprecated:: 2.6
This function is deprecated and will be removed in a future version.
Parameters
----------
name: :class:`str`

26
discord/threads.py

@ -121,8 +121,12 @@ class Thread(Messageable, Hashable):
This is always ``True`` for public threads.
archiver_id: Optional[:class:`int`]
The user's ID that archived this thread.
.. note::
Due to an API change, the ``archiver_id`` will always be ``None`` and can only be obtained via the audit log.
auto_archive_duration: :class:`int`
The duration in minutes until the thread is automatically archived due to inactivity.
The duration in minutes until the thread is automatically hidden from the channel list.
Usually a value of 60, 1440, 4320 and 10080.
archive_timestamp: :class:`datetime.datetime`
An aware timestamp of when the thread's archived status was last updated in UTC.
@ -188,7 +192,7 @@ class Thread(Messageable, Hashable):
self.me: Optional[ThreadMember]
try:
member = data['member']
member = data['member'] # pyright: ignore[reportTypedDictNotRequiredAccess]
except KeyError:
self.me = None
else:
@ -268,12 +272,12 @@ class Thread(Messageable, Hashable):
.. versionadded:: 2.1
"""
tags = []
if self.parent is None or self.parent.type != ChannelType.forum:
if self.parent is None or self.parent.type not in (ChannelType.forum, ChannelType.media):
return tags
parent = self.parent
for tag_id in self._applied_tags:
tag = parent.get_tag(tag_id)
tag = parent.get_tag(tag_id) # type: ignore # parent here will be ForumChannel instance
if tag is not None:
tags.append(tag)
@ -608,7 +612,7 @@ class Thread(Messageable, Hashable):
Whether non-moderators can add other non-moderators to this thread.
Only available for private threads.
auto_archive_duration: :class:`int`
The new duration in minutes before a thread is automatically archived for inactivity.
The new duration in minutes before a thread is automatically hidden from the channel list.
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
slowmode_delay: :class:`int`
Specifies the slowmode rate limit for user in this thread, in seconds.
@ -846,13 +850,21 @@ class Thread(Messageable, Hashable):
members = await self._state.http.get_thread_members(self.id)
return [ThreadMember(parent=self, data=data) for data in members]
async def delete(self) -> None:
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes this thread.
You must have :attr:`~Permissions.manage_threads` to delete threads.
Parameters
-----------
reason: Optional[:class:`str`]
The reason for deleting this thread.
Shows up on the audit log.
.. versionadded:: 2.4
Raises
-------
Forbidden
@ -860,7 +872,7 @@ class Thread(Messageable, Hashable):
HTTPException
Deleting the thread failed.
"""
await self._state.http.delete_channel(self.id)
await self._state.http.delete_channel(self.id, reason=reason)
def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.

1
discord/types/activity.py

@ -93,6 +93,7 @@ class Activity(_BaseActivity, total=False):
state: Optional[str]
details: Optional[str]
timestamps: ActivityTimestamps
platform: Optional[str]
assets: ActivityAssets
party: ActivityParty
application_id: Snowflake

31
discord/types/appinfo.py

@ -24,12 +24,13 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import TypedDict, List, Optional
from typing import Literal, Dict, TypedDict, List, Optional
from typing_extensions import NotRequired
from .user import User
from .team import Team
from .snowflake import Snowflake
from .emoji import Emoji
class InstallParams(TypedDict):
@ -37,6 +38,10 @@ class InstallParams(TypedDict):
permissions: str
class AppIntegrationTypeConfig(TypedDict):
oauth2_install_params: NotRequired[InstallParams]
class BaseAppInfo(TypedDict):
id: Snowflake
name: str
@ -44,10 +49,18 @@ class BaseAppInfo(TypedDict):
icon: Optional[str]
summary: str
description: str
flags: int
approximate_user_install_count: NotRequired[int]
cover_image: NotRequired[str]
terms_of_service_url: NotRequired[str]
privacy_policy_url: NotRequired[str]
rpc_origins: NotRequired[List[str]]
interactions_endpoint_url: NotRequired[Optional[str]]
redirect_uris: NotRequired[List[str]]
role_connections_verification_url: NotRequired[Optional[str]]
class AppInfo(BaseAppInfo):
rpc_origins: List[str]
owner: User
bot_public: bool
bot_require_code_grant: bool
@ -55,26 +68,24 @@ class AppInfo(BaseAppInfo):
guild_id: NotRequired[Snowflake]
primary_sku_id: NotRequired[Snowflake]
slug: NotRequired[str]
terms_of_service_url: NotRequired[str]
privacy_policy_url: NotRequired[str]
hook: NotRequired[bool]
max_participants: NotRequired[int]
tags: NotRequired[List[str]]
install_params: NotRequired[InstallParams]
custom_install_url: NotRequired[str]
role_connections_verification_url: NotRequired[str]
integration_types_config: NotRequired[Dict[Literal['0', '1'], AppIntegrationTypeConfig]]
class PartialAppInfo(BaseAppInfo, total=False):
rpc_origins: List[str]
cover_image: str
hook: bool
terms_of_service_url: str
privacy_policy_url: str
max_participants: int
flags: int
approximate_guild_count: int
class GatewayAppInfo(TypedDict):
id: Snowflake
flags: int
class ListAppEmojis(TypedDict):
items: List[Emoji]

24
discord/types/audit_log.py

@ -37,6 +37,7 @@ from .role import Role
from .channel import ChannelType, DefaultReaction, PrivacyLevel, VideoQualityMode, PermissionOverwrite, ForumTag
from .threads import Thread
from .command import ApplicationCommand, ApplicationCommandPermissions
from .automod import AutoModerationTriggerMetadata
AuditLogEvent = Literal[
1,
@ -87,12 +88,17 @@ AuditLogEvent = Literal[
111,
112,
121,
130,
131,
132,
140,
141,
142,
143,
144,
145,
150,
151,
]
@ -109,6 +115,7 @@ class _AuditLogChange_Str(TypedDict):
'permissions',
'tags',
'unicode_emoji',
'emoji_name',
]
new_value: str
old_value: str
@ -133,6 +140,8 @@ class _AuditLogChange_Snowflake(TypedDict):
'channel_id',
'inviter_id',
'guild_id',
'user_id',
'sound_id',
]
new_value: Snowflake
old_value: Snowflake
@ -180,6 +189,12 @@ class _AuditLogChange_Int(TypedDict):
old_value: int
class _AuditLogChange_Float(TypedDict):
key: Literal['volume']
new_value: float
old_value: float
class _AuditLogChange_ListRole(TypedDict):
key: Literal['$add', '$remove']
new_value: List[Role]
@ -276,11 +291,18 @@ class _AuditLogChange_DefaultReactionEmoji(TypedDict):
old_value: Optional[DefaultReaction]
class _AuditLogChange_TriggerMetadata(TypedDict):
key: Literal['trigger_metadata']
new_value: Optional[AutoModerationTriggerMetadata]
old_value: Optional[AutoModerationTriggerMetadata]
AuditLogChange = Union[
_AuditLogChange_Str,
_AuditLogChange_AssetHash,
_AuditLogChange_Snowflake,
_AuditLogChange_Int,
_AuditLogChange_Float,
_AuditLogChange_Bool,
_AuditLogChange_ListRole,
_AuditLogChange_MFALevel,
@ -298,6 +320,7 @@ AuditLogChange = Union[
_AuditLogChange_AppliedTags,
_AuditLogChange_AvailableTags,
_AuditLogChange_DefaultReactionEmoji,
_AuditLogChange_TriggerMetadata,
]
@ -314,6 +337,7 @@ class AuditEntryInfo(TypedDict):
guild_id: Snowflake
auto_moderation_rule_name: str
auto_moderation_rule_trigger_type: str
integration_type: str
class AuditLogEntry(TypedDict):

7
discord/types/automod.py

@ -45,9 +45,13 @@ class _AutoModerationActionMetadataTimeout(TypedDict):
duration_seconds: int
class _AutoModerationActionMetadataCustomMessage(TypedDict):
custom_message: str
class _AutoModerationActionBlockMessage(TypedDict):
type: Literal[1]
metadata: NotRequired[Empty]
metadata: NotRequired[_AutoModerationActionMetadataCustomMessage]
class _AutoModerationActionAlert(TypedDict):
@ -75,6 +79,7 @@ class _AutoModerationTriggerMetadataKeywordPreset(TypedDict):
class _AutoModerationTriggerMetadataMentionLimit(TypedDict):
mention_total_limit: int
mention_raid_protection_enabled: bool
AutoModerationTriggerMetadata = Union[

44
discord/types/channel.py

@ -28,6 +28,7 @@ from typing_extensions import NotRequired
from .user import PartialUser
from .snowflake import Snowflake
from .threads import ThreadMetadata, ThreadMember, ThreadArchiveDuration, ThreadType
from .emoji import PartialEmoji
OverwriteType = Literal[0, 1]
@ -40,7 +41,7 @@ class PermissionOverwrite(TypedDict):
deny: str
ChannelTypeWithoutThread = Literal[0, 1, 2, 3, 4, 5, 6, 13, 15]
ChannelTypeWithoutThread = Literal[0, 1, 2, 3, 4, 5, 6, 13, 15, 16]
ChannelType = Union[ChannelTypeWithoutThread, ThreadType]
@ -89,6 +90,20 @@ class VoiceChannel(_BaseTextChannel):
video_quality_mode: NotRequired[VideoQualityMode]
VoiceChannelEffectAnimationType = Literal[0, 1]
class VoiceChannelEffect(TypedDict):
guild_id: Snowflake
channel_id: Snowflake
user_id: Snowflake
emoji: NotRequired[Optional[PartialEmoji]]
animation_type: NotRequired[VoiceChannelEffectAnimationType]
animation_id: NotRequired[int]
sound_id: NotRequired[Union[int, str]]
sound_volume: NotRequired[float]
class CategoryChannel(_BaseGuildChannel):
type: Literal[4]
@ -134,30 +149,49 @@ class ForumTag(TypedDict):
emoji_name: Optional[str]
ForumOrderType = Literal[0, 1]
ForumLayoutType = Literal[0, 1, 2]
class ForumChannel(_BaseTextChannel):
type: Literal[15]
class _BaseForumChannel(_BaseTextChannel):
available_tags: List[ForumTag]
default_reaction_emoji: Optional[DefaultReaction]
default_sort_order: Optional[ForumOrderType]
default_forum_layout: NotRequired[ForumLayoutType]
flags: NotRequired[int]
GuildChannel = Union[TextChannel, NewsChannel, VoiceChannel, CategoryChannel, StageChannel, ThreadChannel, ForumChannel]
class ForumChannel(_BaseForumChannel):
type: Literal[15]
class MediaChannel(_BaseForumChannel):
type: Literal[16]
class DMChannel(_BaseChannel):
GuildChannel = Union[
TextChannel, NewsChannel, VoiceChannel, CategoryChannel, StageChannel, ThreadChannel, ForumChannel, MediaChannel
]
class _BaseDMChannel(_BaseChannel):
type: Literal[1]
last_message_id: Optional[Snowflake]
class DMChannel(_BaseDMChannel):
recipients: List[PartialUser]
class InteractionDMChannel(_BaseDMChannel):
recipients: NotRequired[List[PartialUser]]
class GroupDMChannel(_BaseChannel):
type: Literal[3]
icon: Optional[str]
owner_id: Snowflake
recipients: List[PartialUser]
Channel = Union[GuildChannel, DMChannel, GroupDMChannel]

4
discord/types/command.py

@ -29,9 +29,11 @@ from typing_extensions import NotRequired, Required
from .channel import ChannelType
from .snowflake import Snowflake
from .interactions import InteractionContextType
ApplicationCommandType = Literal[1, 2, 3]
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
ApplicationIntegrationType = Literal[0, 1]
class _BaseApplicationCommandOption(TypedDict):
@ -141,6 +143,8 @@ class _BaseApplicationCommand(TypedDict):
id: Snowflake
application_id: Snowflake
name: str
contexts: List[InteractionContextType]
integration_types: List[ApplicationIntegrationType]
dm_permission: NotRequired[Optional[bool]]
default_member_permissions: NotRequired[Optional[str]]
nsfw: NotRequired[bool]

14
discord/types/components.py

@ -31,8 +31,9 @@ from .emoji import PartialEmoji
from .channel import ChannelType
ComponentType = Literal[1, 2, 3, 4]
ButtonStyle = Literal[1, 2, 3, 4, 5]
ButtonStyle = Literal[1, 2, 3, 4, 5, 6]
TextStyle = Literal[1, 2]
DefaultValueType = Literal['user', 'role', 'channel']
class ActionRow(TypedDict):
@ -48,6 +49,7 @@ class ButtonComponent(TypedDict):
disabled: NotRequired[bool]
emoji: NotRequired[PartialEmoji]
label: NotRequired[str]
sku_id: NotRequired[str]
class SelectOption(TypedDict):
@ -66,6 +68,11 @@ class SelectComponent(TypedDict):
disabled: NotRequired[bool]
class SelectDefaultValues(TypedDict):
id: int
type: DefaultValueType
class StringSelectComponent(SelectComponent):
type: Literal[3]
options: NotRequired[List[SelectOption]]
@ -73,19 +80,23 @@ class StringSelectComponent(SelectComponent):
class UserSelectComponent(SelectComponent):
type: Literal[5]
default_values: NotRequired[List[SelectDefaultValues]]
class RoleSelectComponent(SelectComponent):
type: Literal[6]
default_values: NotRequired[List[SelectDefaultValues]]
class MentionableSelectComponent(SelectComponent):
type: Literal[7]
default_values: NotRequired[List[SelectDefaultValues]]
class ChannelSelectComponent(SelectComponent):
type: Literal[8]
channel_types: NotRequired[List[ChannelType]]
default_values: NotRequired[List[SelectDefaultValues]]
class TextInput(TypedDict):
@ -104,6 +115,7 @@ class SelectMenu(SelectComponent):
type: Literal[3, 5, 6, 7, 8]
options: NotRequired[List[SelectOption]]
channel_types: NotRequired[List[ChannelType]]
default_values: NotRequired[List[SelectDefaultValues]]
ActionRowChildComponent = Union[ButtonComponent, SelectMenu, TextInput]

26
discord/types/embed.py

@ -38,25 +38,12 @@ class EmbedField(TypedDict):
inline: NotRequired[bool]
class EmbedThumbnail(TypedDict, total=False):
url: Required[str]
proxy_url: str
height: int
width: int
class EmbedVideo(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedImage(TypedDict, total=False):
class EmbedMedia(TypedDict, total=False):
url: Required[str]
proxy_url: str
height: int
width: int
flags: int
class EmbedProvider(TypedDict, total=False):
@ -71,7 +58,7 @@ class EmbedAuthor(TypedDict, total=False):
proxy_icon_url: str
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link', 'poll_result']
class Embed(TypedDict, total=False):
@ -82,9 +69,10 @@ class Embed(TypedDict, total=False):
timestamp: str
color: int
footer: EmbedFooter
image: EmbedImage
thumbnail: EmbedThumbnail
video: EmbedVideo
image: EmbedMedia
thumbnail: EmbedMedia
video: EmbedMedia
provider: EmbedProvider
author: EmbedAuthor
fields: List[EmbedField]
flags: int

2
discord/types/emoji.py

@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from typing import Optional, TypedDict
from typing_extensions import NotRequired
from .snowflake import Snowflake, SnowflakeList
from .user import User
@ -30,6 +31,7 @@ from .user import User
class PartialEmoji(TypedDict):
id: Optional[Snowflake]
name: Optional[str]
animated: NotRequired[bool]
class Emoji(PartialEmoji, total=False):

46
discord/types/gateway.py

@ -27,23 +27,26 @@ from typing_extensions import NotRequired, Required
from .automod import AutoModerationAction, AutoModerationRuleTriggerType
from .activity import PartialPresenceUpdate
from .sku import Entitlement
from .voice import GuildVoiceState
from .integration import BaseIntegration, IntegrationApplication
from .role import Role
from .channel import ChannelType, StageInstance
from .channel import ChannelType, StageInstance, VoiceChannelEffect
from .interactions import Interaction
from .invite import InviteTargetType
from .emoji import Emoji, PartialEmoji
from .member import MemberWithUser
from .snowflake import Snowflake
from .message import Message
from .message import Message, ReactionType
from .sticker import GuildSticker
from .appinfo import GatewayAppInfo, PartialAppInfo
from .guild import Guild, UnavailableGuild
from .user import User
from .user import User, AvatarDecorationData
from .threads import Thread, ThreadMember
from .scheduled_event import GuildScheduledEvent
from .audit_log import AuditLogEntry
from .soundboard import SoundboardSound
from .subscription import Subscription
class SessionStartLimit(TypedDict):
@ -89,8 +92,7 @@ class MessageDeleteBulkEvent(TypedDict):
guild_id: NotRequired[Snowflake]
class MessageUpdateEvent(Message):
channel_id: Snowflake
MessageUpdateEvent = MessageCreateEvent
class MessageReactionAddEvent(TypedDict):
@ -100,6 +102,10 @@ class MessageReactionAddEvent(TypedDict):
emoji: PartialEmoji
member: NotRequired[MemberWithUser]
guild_id: NotRequired[Snowflake]
message_author_id: NotRequired[Snowflake]
burst: bool
burst_colors: NotRequired[List[str]]
type: ReactionType
class MessageReactionRemoveEvent(TypedDict):
@ -108,6 +114,8 @@ class MessageReactionRemoveEvent(TypedDict):
message_id: Snowflake
emoji: PartialEmoji
guild_id: NotRequired[Snowflake]
burst: bool
type: ReactionType
class MessageReactionRemoveAllEvent(TypedDict):
@ -223,6 +231,7 @@ class GuildMemberUpdateEvent(TypedDict):
mute: NotRequired[bool]
pending: NotRequired[bool]
communication_disabled_until: NotRequired[str]
avatar_decoration_data: NotRequired[AvatarDecorationData]
class GuildEmojisUpdateEvent(TypedDict):
@ -311,6 +320,19 @@ class _GuildScheduledEventUsersEvent(TypedDict):
GuildScheduledEventUserAdd = GuildScheduledEventUserRemove = _GuildScheduledEventUsersEvent
VoiceStateUpdateEvent = GuildVoiceState
VoiceChannelEffectSendEvent = VoiceChannelEffect
GuildSoundBoardSoundCreateEvent = GuildSoundBoardSoundUpdateEvent = SoundboardSound
class GuildSoundBoardSoundsUpdateEvent(TypedDict):
guild_id: Snowflake
soundboard_sounds: List[SoundboardSound]
class GuildSoundBoardSoundDeleteEvent(TypedDict):
sound_id: Snowflake
guild_id: Snowflake
class VoiceServerUpdateEvent(TypedDict):
@ -343,3 +365,17 @@ class AutoModerationActionExecution(TypedDict):
class GuildAuditLogEntryCreate(AuditLogEntry):
guild_id: Snowflake
EntitlementCreateEvent = EntitlementUpdateEvent = EntitlementDeleteEvent = Entitlement
class PollVoteActionEvent(TypedDict):
user_id: Snowflake
channel_id: Snowflake
message_id: Snowflake
guild_id: NotRequired[Snowflake]
answer_id: int
SubscriptionCreateEvent = SubscriptionUpdateEvent = SubscriptionDeleteEvent = Subscription

24
discord/types/guild.py

@ -37,6 +37,7 @@ from .member import Member
from .emoji import Emoji
from .user import User
from .threads import Thread
from .soundboard import SoundboardSound
class Ban(TypedDict):
@ -49,6 +50,11 @@ class UnavailableGuild(TypedDict):
unavailable: NotRequired[bool]
class IncidentData(TypedDict):
invites_disabled_until: NotRequired[Optional[str]]
dms_disabled_until: NotRequired[Optional[str]]
DefaultMessageNotificationLevel = Literal[0, 1]
ExplicitContentFilterLevel = Literal[0, 1, 2]
MFALevel = Literal[0, 1]
@ -84,6 +90,9 @@ GuildFeature = Literal[
'VERIFIED',
'VIP_REGIONS',
'WELCOME_SCREEN_ENABLED',
'RAID_ALERTS_DISABLED',
'SOUNDBOARD',
'MORE_SOUNDBOARD',
]
@ -96,6 +105,7 @@ class _BaseGuildPreview(UnavailableGuild):
stickers: List[GuildSticker]
features: List[GuildFeature]
description: Optional[str]
incidents_data: Optional[IncidentData]
class _GuildPreviewUnique(TypedDict):
@ -147,6 +157,7 @@ class Guild(_BaseGuildPreview):
max_members: NotRequired[int]
premium_subscription_count: NotRequired[int]
max_video_channel_users: NotRequired[int]
soundboard_sounds: NotRequired[List[SoundboardSound]]
class InviteGuild(Guild, total=False):
@ -161,11 +172,15 @@ class GuildPrune(TypedDict):
pruned: Optional[int]
class GuildMFALevel(TypedDict):
level: MFALevel
class ChannelPositionUpdate(TypedDict):
id: Snowflake
position: Optional[int]
lock_permissions: Optional[bool]
parent_id: Optional[Snowflake]
lock_permissions: NotRequired[Optional[bool]]
parent_id: NotRequired[Optional[Snowflake]]
class _RolePositionRequired(TypedDict):
@ -174,3 +189,8 @@ class _RolePositionRequired(TypedDict):
class RolePositionUpdate(_RolePositionRequired, total=False):
position: Optional[Snowflake]
class BulkBanUserResponse(TypedDict):
banned_users: Optional[List[Snowflake]]
failed_users: Optional[List[Snowflake]]

113
discord/types/interactions.py

@ -24,22 +24,36 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, Union
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, Union, Optional
from typing_extensions import NotRequired
from .channel import ChannelTypeWithoutThread, ThreadMetadata
from .threads import ThreadType
from .channel import ChannelTypeWithoutThread, GuildChannel, InteractionDMChannel, GroupDMChannel
from .sku import Entitlement
from .threads import ThreadType, ThreadMetadata
from .member import Member
from .message import Attachment
from .role import Role
from .snowflake import Snowflake
from .user import User
from .guild import GuildFeature
if TYPE_CHECKING:
from .message import Message
InteractionType = Literal[1, 2, 3, 4, 5]
InteractionResponseType = Literal[
1,
4,
5,
6,
7,
8,
9,
10,
]
InteractionContextType = Literal[0, 1, 2]
InteractionInstallationType = Literal[0, 1]
class _BasePartialChannel(TypedDict):
@ -50,6 +64,14 @@ class _BasePartialChannel(TypedDict):
class PartialChannel(_BasePartialChannel):
type: ChannelTypeWithoutThread
topic: NotRequired[str]
position: int
nsfw: bool
flags: int
rate_limit_per_user: int
parent_id: Optional[Snowflake]
last_message_id: Optional[Snowflake]
last_pin_timestamp: NotRequired[str]
class PartialThread(_BasePartialChannel):
@ -67,6 +89,12 @@ class ResolvedData(TypedDict, total=False):
attachments: Dict[str, Attachment]
class PartialInteractionGuild(TypedDict):
id: Snowflake
locale: str
features: List[GuildFeature]
class _BaseApplicationCommandInteractionDataOption(TypedDict):
name: str
@ -203,10 +231,17 @@ class _BaseInteraction(TypedDict):
token: str
version: Literal[1]
guild_id: NotRequired[Snowflake]
guild: NotRequired[PartialInteractionGuild]
channel_id: NotRequired[Snowflake]
channel: Union[GuildChannel, InteractionDMChannel, GroupDMChannel]
app_permissions: NotRequired[str]
locale: NotRequired[str]
guild_locale: NotRequired[str]
entitlement_sku_ids: NotRequired[List[Snowflake]]
entitlements: NotRequired[List[Entitlement]]
authorizing_integration_owners: Dict[Literal['0', '1'], Snowflake]
context: NotRequired[InteractionContextType]
attachment_size_limit: int
class PingInteraction(_BaseInteraction):
@ -237,3 +272,75 @@ class MessageInteraction(TypedDict):
name: str
user: User
member: NotRequired[Member]
class _MessageInteractionMetadata(TypedDict):
id: Snowflake
user: User
authorizing_integration_owners: Dict[Literal['0', '1'], Snowflake]
original_response_message_id: NotRequired[Snowflake]
class _ApplicationCommandMessageInteractionMetadata(_MessageInteractionMetadata):
type: Literal[2]
# command_type: Literal[1, 2, 3, 4]
class UserApplicationCommandMessageInteractionMetadata(_ApplicationCommandMessageInteractionMetadata):
# command_type: Literal[2]
target_user: User
class MessageApplicationCommandMessageInteractionMetadata(_ApplicationCommandMessageInteractionMetadata):
# command_type: Literal[3]
target_message_id: Snowflake
ApplicationCommandMessageInteractionMetadata = Union[
_ApplicationCommandMessageInteractionMetadata,
UserApplicationCommandMessageInteractionMetadata,
MessageApplicationCommandMessageInteractionMetadata,
]
class MessageComponentMessageInteractionMetadata(_MessageInteractionMetadata):
type: Literal[3]
interacted_message_id: Snowflake
class ModalSubmitMessageInteractionMetadata(_MessageInteractionMetadata):
type: Literal[5]
triggering_interaction_metadata: Union[
ApplicationCommandMessageInteractionMetadata, MessageComponentMessageInteractionMetadata
]
MessageInteractionMetadata = Union[
ApplicationCommandMessageInteractionMetadata,
MessageComponentMessageInteractionMetadata,
ModalSubmitMessageInteractionMetadata,
]
class InteractionCallbackResponse(TypedDict):
id: Snowflake
type: InteractionType
activity_instance_id: NotRequired[str]
response_message_id: NotRequired[Snowflake]
response_message_loading: NotRequired[bool]
response_message_ephemeral: NotRequired[bool]
class InteractionCallbackActivity(TypedDict):
id: str
class InteractionCallbackResource(TypedDict):
type: InteractionResponseType
activity_instance: NotRequired[InteractionCallbackActivity]
message: NotRequired[Message]
class InteractionCallback(TypedDict):
interaction: InteractionCallbackResponse
resource: NotRequired[InteractionCallbackResource]

4
discord/types/invite.py

@ -35,6 +35,7 @@ from .user import PartialUser
from .appinfo import PartialAppInfo
InviteTargetType = Literal[1, 2]
InviteType = Literal[0, 1, 2]
class _InviteMetadata(TypedDict, total=False):
@ -63,6 +64,8 @@ class Invite(IncompleteInvite, total=False):
target_type: InviteTargetType
target_application: PartialAppInfo
guild_scheduled_event: GuildScheduledEvent
type: InviteType
flags: NotRequired[int]
class InviteWithCounts(Invite, _GuildPreviewUnique):
@ -82,6 +85,7 @@ class GatewayInviteCreate(TypedDict):
target_type: NotRequired[InviteTargetType]
target_user: NotRequired[PartialUser]
target_application: NotRequired[PartialAppInfo]
flags: NotRequired[int]
class GatewayInviteDelete(TypedDict):

8
discord/types/member.py

@ -24,7 +24,8 @@ DEALINGS IN THE SOFTWARE.
from typing import Optional, TypedDict
from .snowflake import SnowflakeList
from .user import User
from .user import User, AvatarDecorationData
from typing_extensions import NotRequired
class Nickname(TypedDict):
@ -33,7 +34,7 @@ class Nickname(TypedDict):
class PartialMember(TypedDict):
roles: SnowflakeList
joined_at: str
joined_at: Optional[str] # null if guest
deaf: bool
mute: bool
flags: int
@ -47,6 +48,8 @@ class Member(PartialMember, total=False):
pending: bool
permissions: str
communication_disabled_until: str
banner: NotRequired[Optional[str]]
avatar_decoration_data: NotRequired[AvatarDecorationData]
class _OptionalMemberWithUser(PartialMember, total=False):
@ -56,6 +59,7 @@ class _OptionalMemberWithUser(PartialMember, total=False):
pending: bool
permissions: str
communication_disabled_until: str
avatar_decoration_data: NotRequired[AvatarDecorationData]
class MemberWithUser(_OptionalMemberWithUser):

98
discord/types/message.py

@ -34,8 +34,10 @@ from .emoji import PartialEmoji
from .embed import Embed
from .channel import ChannelType
from .components import Component
from .interactions import MessageInteraction
from .interactions import MessageInteraction, MessageInteractionMetadata
from .sticker import StickerItem
from .threads import Thread
from .poll import Poll
class PartialMessage(TypedDict):
@ -50,10 +52,21 @@ class ChannelMention(TypedDict):
name: str
class ReactionCountDetails(TypedDict):
burst: int
normal: int
ReactionType = Literal[0, 1]
class Reaction(TypedDict):
count: int
me: bool
emoji: PartialEmoji
me_burst: bool
count_details: ReactionCountDetails
burst_colors: List[str]
class Attachment(TypedDict):
@ -68,6 +81,9 @@ class Attachment(TypedDict):
content_type: NotRequired[str]
spoiler: NotRequired[bool]
ephemeral: NotRequired[bool]
duration_secs: NotRequired[float]
waveform: NotRequired[str]
flags: NotRequired[int]
MessageActivityType = Literal[1, 2, 3, 5]
@ -86,7 +102,11 @@ class MessageApplication(TypedDict):
cover_image: NotRequired[str]
MessageReferenceType = Literal[0, 1]
class MessageReference(TypedDict, total=False):
type: MessageReferenceType
message_id: Snowflake
channel_id: Required[Snowflake]
guild_id: Snowflake
@ -100,11 +120,78 @@ class RoleSubscriptionData(TypedDict):
is_renewal: bool
PurchaseNotificationResponseType = Literal[0]
class GuildProductPurchase(TypedDict):
listing_id: Snowflake
product_name: str
class PurchaseNotificationResponse(TypedDict):
type: PurchaseNotificationResponseType
guild_product_purchase: Optional[GuildProductPurchase]
class CallMessage(TypedDict):
participants: SnowflakeList
ended_timestamp: NotRequired[Optional[str]]
MessageType = Literal[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
14,
15,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
36,
37,
38,
39,
44,
46,
]
class MessageSnapshot(TypedDict):
type: MessageType
content: str
embeds: List[Embed]
attachments: List[Attachment]
timestamp: str
edited_timestamp: Optional[str]
flags: NotRequired[int]
mentions: List[UserWithMember]
mention_roles: SnowflakeList
sticker_items: NotRequired[List[StickerItem]]
components: NotRequired[List[Component]]
class Message(PartialMessage):
id: Snowflake
author: User
@ -118,6 +205,7 @@ class Message(PartialMessage):
attachments: List[Attachment]
embeds: List[Embed]
pinned: bool
poll: NotRequired[Poll]
type: MessageType
member: NotRequired[Member]
mention_channels: NotRequired[List[ChannelMention]]
@ -131,10 +219,14 @@ class Message(PartialMessage):
flags: NotRequired[int]
sticker_items: NotRequired[List[StickerItem]]
referenced_message: NotRequired[Optional[Message]]
interaction: NotRequired[MessageInteraction]
interaction: NotRequired[MessageInteraction] # deprecated, use interaction_metadata
interaction_metadata: NotRequired[MessageInteractionMetadata]
components: NotRequired[List[Component]]
position: NotRequired[int]
role_subscription_data: NotRequired[RoleSubscriptionData]
thread: NotRequired[Thread]
call: NotRequired[CallMessage]
purchase_notification: NotRequired[PurchaseNotificationResponse]
AllowedMentionType = Literal['roles', 'users', 'everyone']

88
discord/types/poll.py

@ -0,0 +1,88 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, TypedDict, Optional, Literal, TYPE_CHECKING
from typing_extensions import NotRequired
from .snowflake import Snowflake
if TYPE_CHECKING:
from .user import User
from .emoji import PartialEmoji
LayoutType = Literal[1] # 1 = Default
class PollMedia(TypedDict):
text: str
emoji: NotRequired[Optional[PartialEmoji]]
class PollAnswer(TypedDict):
poll_media: PollMedia
class PollAnswerWithID(PollAnswer):
answer_id: int
class PollAnswerCount(TypedDict):
id: Snowflake
count: int
me_voted: bool
class PollAnswerVoters(TypedDict):
users: List[User]
class PollResult(TypedDict):
is_finalized: bool
answer_counts: List[PollAnswerCount]
class PollCreate(TypedDict):
allow_multiselect: bool
answers: List[PollAnswer]
duration: float
layout_type: LayoutType
question: PollMedia
# We don't subclass Poll as it will
# still have the duration field, which
# is converted into expiry when poll is
# fetched from a message or returned
# by a `send` method in a Messageable
class Poll(TypedDict):
allow_multiselect: bool
answers: List[PollAnswerWithID]
expiry: str
layout_type: LayoutType
question: PollMedia
results: PollResult

1
discord/types/role.py

@ -39,6 +39,7 @@ class Role(TypedDict):
permissions: str
managed: bool
mentionable: bool
flags: int
icon: NotRequired[Optional[str]]
unicode_emoji: NotRequired[Optional[str]]
tags: NotRequired[RoleTags]

53
discord/types/sku.py

@ -0,0 +1,53 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TypedDict, Optional, Literal
from typing_extensions import NotRequired
class SKU(TypedDict):
id: str
type: int
application_id: str
name: str
slug: str
flags: int
class Entitlement(TypedDict):
id: str
sku_id: str
application_id: str
user_id: Optional[str]
type: int
deleted: bool
starts_at: NotRequired[str]
ends_at: NotRequired[str]
guild_id: NotRequired[str]
consumed: NotRequired[bool]
EntitlementOwnerType = Literal[1, 2]

49
discord/types/soundboard.py

@ -0,0 +1,49 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypedDict, Optional, Union
from typing_extensions import NotRequired
from .snowflake import Snowflake
from .user import User
class BaseSoundboardSound(TypedDict):
sound_id: Union[Snowflake, str] # basic string number when it's a default sound
volume: float
class SoundboardSound(BaseSoundboardSound):
name: str
emoji_name: Optional[str]
emoji_id: Optional[Snowflake]
user_id: NotRequired[Snowflake]
available: bool
guild_id: NotRequired[Snowflake]
user: NotRequired[User]
class SoundboardDefaultSound(BaseSoundboardSound):
name: str
emoji_name: str

4
discord/types/sticker.py

@ -30,7 +30,7 @@ from typing_extensions import NotRequired
from .snowflake import Snowflake
from .user import User
StickerFormatType = Literal[1, 2, 3]
StickerFormatType = Literal[1, 2, 3, 4]
class StickerItem(TypedDict):
@ -55,7 +55,7 @@ class StandardSticker(BaseSticker):
class GuildSticker(BaseSticker):
type: Literal[2]
available: bool
available: NotRequired[bool]
guild_id: Snowflake
user: NotRequired[User]

43
discord/types/subscription.py

@ -0,0 +1,43 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, Optional, TypedDict
from .snowflake import Snowflake
SubscriptionStatus = Literal[0, 1, 2]
class Subscription(TypedDict):
id: Snowflake
user_id: Snowflake
sku_ids: List[Snowflake]
entitlement_ids: List[Snowflake]
current_period_start: str
current_period_end: str
status: SubscriptionStatus
canceled_at: Optional[str]
renewal_sku_ids: Optional[List[Snowflake]]

3
discord/types/team.py

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import TypedDict, List, Optional
from typing import Literal, TypedDict, List, Optional
from .user import PartialUser
from .snowflake import Snowflake
@ -35,6 +35,7 @@ class TeamMember(TypedDict):
membership_state: int
permissions: List[str]
team_id: Snowflake
role: Literal['admin', 'developer', 'read_only']
class Team(TypedDict):

12
discord/types/user.py

@ -24,6 +24,12 @@ DEALINGS IN THE SOFTWARE.
from .snowflake import Snowflake
from typing import Literal, Optional, TypedDict
from typing_extensions import NotRequired
class AvatarDecorationData(TypedDict):
asset: str
sku_id: Snowflake
class PartialUser(TypedDict):
@ -31,16 +37,18 @@ class PartialUser(TypedDict):
username: str
discriminator: str
avatar: Optional[str]
global_name: Optional[str]
avatar_decoration_data: NotRequired[AvatarDecorationData]
PremiumType = Literal[0, 1, 2]
PremiumType = Literal[0, 1, 2, 3]
class User(PartialUser, total=False):
bot: bool
system: bool
mfa_enabled: bool
local: str
locale: str
verified: bool
email: Optional[str]
flags: int

7
discord/types/voice.py

@ -29,7 +29,12 @@ from .snowflake import Snowflake
from .member import MemberWithUser
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
SupportedModes = Literal[
'aead_xchacha20_poly1305_rtpsize',
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
'xsalsa20_poly1305',
]
class _VoiceState(TypedDict):

1
discord/ui/__init__.py

@ -15,3 +15,4 @@ from .item import *
from .button import *
from .select import *
from .text_input import *
from .dynamic import *

47
discord/ui/button.py

@ -61,12 +61,14 @@ class Button(Item[V]):
custom_id: Optional[:class:`str`]
The ID of the button that gets received during an interaction.
If this button is for a URL, it does not have a custom ID.
Can only be up to 100 characters.
url: Optional[:class:`str`]
The URL this button sends you to.
disabled: :class:`bool`
Whether the button is disabled or not.
label: Optional[:class:`str`]
The label of the button, if any.
Can only be up to 80 characters.
emoji: Optional[Union[:class:`.PartialEmoji`, :class:`.Emoji`, :class:`str`]]
The emoji of the button, if available.
row: Optional[:class:`int`]
@ -75,6 +77,11 @@ class Button(Item[V]):
like to control the relative positioning of the row then passing an index is advised.
For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
sku_id: Optional[:class:`int`]
The SKU ID this button sends you to. Can't be combined with ``url``, ``label``, ``emoji``
nor ``custom_id``.
.. versionadded:: 2.4
"""
__item_repr_attributes__: Tuple[str, ...] = (
@ -84,6 +91,7 @@ class Button(Item[V]):
'label',
'emoji',
'row',
'sku_id',
)
def __init__(
@ -96,13 +104,18 @@ class Button(Item[V]):
url: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
row: Optional[int] = None,
sku_id: Optional[int] = None,
):
super().__init__()
if custom_id is not None and url is not None:
raise TypeError('cannot mix both url and custom_id with Button')
if custom_id is not None and (url is not None or sku_id is not None):
raise TypeError('cannot mix both url or sku_id and custom_id with Button')
if url is not None and sku_id is not None:
raise TypeError('cannot mix both url and sku_id')
requires_custom_id = url is None and sku_id is None
self._provided_custom_id = custom_id is not None
if url is None and custom_id is None:
if requires_custom_id and custom_id is None:
custom_id = os.urandom(16).hex()
if custom_id is not None and not isinstance(custom_id, str):
@ -111,6 +124,9 @@ class Button(Item[V]):
if url is not None:
style = ButtonStyle.link
if sku_id is not None:
style = ButtonStyle.premium
if emoji is not None:
if isinstance(emoji, str):
emoji = PartialEmoji.from_str(emoji)
@ -126,6 +142,7 @@ class Button(Item[V]):
label=label,
style=style,
emoji=emoji,
sku_id=sku_id,
)
self.row = row
@ -200,6 +217,20 @@ class Button(Item[V]):
else:
self._underlying.emoji = None
@property
def sku_id(self) -> Optional[int]:
"""Optional[:class:`int`]: The SKU ID this button sends you to.
.. versionadded:: 2.4
"""
return self._underlying.sku_id
@sku_id.setter
def sku_id(self, value: Optional[int]) -> None:
if value is not None:
self.style = ButtonStyle.premium
self._underlying.sku_id = value
@classmethod
def from_component(cls, button: ButtonComponent) -> Self:
return cls(
@ -210,6 +241,7 @@ class Button(Item[V]):
url=button.url,
emoji=button.emoji,
row=None,
sku_id=button.sku_id,
)
@property
@ -248,19 +280,21 @@ def button(
.. note::
Buttons with a URL cannot be created with this function.
Buttons with a URL or an SKU cannot be created with this function.
Consider creating a :class:`Button` manually instead.
This is because buttons with a URL do not have a callback
This is because these buttons cannot have a callback
associated with them since Discord does not do any processing
with it.
with them.
Parameters
------------
label: Optional[:class:`str`]
The label of the button, if any.
Can only be up to 80 characters.
custom_id: Optional[:class:`str`]
The ID of the button that gets received during an interaction.
It is recommended not to set this parameter to prevent conflicts.
Can only be up to 100 characters.
style: :class:`.ButtonStyle`
The style of the button. Defaults to :attr:`.ButtonStyle.grey`.
disabled: :class:`bool`
@ -289,6 +323,7 @@ def button(
'label': label,
'emoji': emoji,
'row': row,
'sku_id': None,
}
return func

216
discord/ui/dynamic.py

@ -0,0 +1,216 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import ClassVar, Dict, Generic, Optional, Tuple, Type, TypeVar, TYPE_CHECKING, Any, Union
import re
from .item import Item
from .._types import ClientT
__all__ = ('DynamicItem',)
BaseT = TypeVar('BaseT', bound='Item[Any]', covariant=True)
if TYPE_CHECKING:
from typing_extensions import TypeVar, Self
from ..interactions import Interaction
from ..components import Component
from ..enums import ComponentType
from .view import View
V = TypeVar('V', bound='View', covariant=True, default=View)
else:
V = TypeVar('V', bound='View', covariant=True)
class DynamicItem(Generic[BaseT], Item['View']):
"""Represents an item with a dynamic ``custom_id`` that can be used to store state within
that ``custom_id``.
The ``custom_id`` parsing is done using the ``re`` module by passing a ``template``
parameter to the class parameter list.
This item is generated every time the component is dispatched. This means that
any variable that holds an instance of this class will eventually be out of date
and should not be used long term. Their only purpose is to act as a "template"
for the actual dispatched item.
When this item is generated, :attr:`view` is set to a regular :class:`View` instance
from the original message given from the interaction. This means that custom view
subclasses cannot be accessed from this item.
.. versionadded:: 2.4
Parameters
------------
item: :class:`Item`
The item to wrap with dynamic custom ID parsing.
template: Union[:class:`str`, ``re.Pattern``]
The template to use for parsing the ``custom_id``. This can be a string or a compiled
regular expression. This must be passed as a keyword argument to the class creation.
row: Optional[:class:`int`]
The relative row this button belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
like to control the relative positioning of the row then passing an index is advised.
For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
Attributes
-----------
item: :class:`Item`
The item that is wrapped with dynamic custom ID parsing.
"""
__item_repr_attributes__: Tuple[str, ...] = (
'item',
'template',
)
__discord_ui_compiled_template__: ClassVar[re.Pattern[str]]
def __init_subclass__(cls, *, template: Union[str, re.Pattern[str]]) -> None:
super().__init_subclass__()
cls.__discord_ui_compiled_template__ = re.compile(template) if isinstance(template, str) else template
if not isinstance(cls.__discord_ui_compiled_template__, re.Pattern):
raise TypeError('template must be a str or a re.Pattern')
def __init__(
self,
item: BaseT,
*,
row: Optional[int] = None,
) -> None:
super().__init__()
self.item: BaseT = item
if row is not None:
self.row = row
if not self.item.is_dispatchable():
raise TypeError('item must be dispatchable, e.g. not a URL button')
if not self.template.match(self.custom_id):
raise ValueError(f'item custom_id {self.custom_id!r} must match the template {self.template.pattern!r}')
@property
def template(self) -> re.Pattern[str]:
"""``re.Pattern``: The compiled regular expression that is used to parse the ``custom_id``."""
return self.__class__.__discord_ui_compiled_template__
def to_component_dict(self) -> Dict[str, Any]:
return self.item.to_component_dict()
def _refresh_component(self, component: Component) -> None:
self.item._refresh_component(component)
def _refresh_state(self, interaction: Interaction, data: Dict[str, Any]) -> None:
self.item._refresh_state(interaction, data)
@classmethod
def from_component(cls: Type[Self], component: Component) -> Self:
raise TypeError('Dynamic items cannot be created from components')
@property
def type(self) -> ComponentType:
return self.item.type
def is_dispatchable(self) -> bool:
return self.item.is_dispatchable()
def is_persistent(self) -> bool:
return True
@property
def custom_id(self) -> str:
""":class:`str`: The ID of the dynamic item that gets received during an interaction."""
return self.item.custom_id # type: ignore # This attribute exists for dispatchable items
@custom_id.setter
def custom_id(self, value: str) -> None:
if not isinstance(value, str):
raise TypeError('custom_id must be a str')
if not self.template.match(value):
raise ValueError(f'custom_id must match the template {self.template.pattern!r}')
self.item.custom_id = value # type: ignore # This attribute exists for dispatchable items
self._provided_custom_id = True
@property
def row(self) -> Optional[int]:
return self.item._row
@row.setter
def row(self, value: Optional[int]) -> None:
self.item.row = value
@property
def width(self) -> int:
return self.item.width
@classmethod
async def from_custom_id(
cls: Type[Self], interaction: Interaction[ClientT], item: Item[Any], match: re.Match[str], /
) -> Self:
"""|coro|
A classmethod that is called when the ``custom_id`` of a component matches the
``template`` of the class. This is called when the component is dispatched.
It must return a new instance of the :class:`DynamicItem`.
Subclasses *must* implement this method.
Exceptions raised in this method are logged and ignored.
.. warning::
This method is called before the callback is dispatched, therefore
it means that it is subject to the same timing restrictions as the callback.
Ergo, you must reply to an interaction within 3 seconds of it being
dispatched.
Parameters
------------
interaction: :class:`~discord.Interaction`
The interaction that the component belongs to.
item: :class:`~discord.ui.Item`
The base item that is being dispatched.
match: ``re.Match``
The match object that was created from the ``template``
matching the ``custom_id``.
Returns
--------
:class:`DynamicItem`
The new instance of the :class:`DynamicItem` with information
from the ``match`` object.
"""
raise NotImplementedError
async def callback(self, interaction: Interaction[ClientT]) -> Any:
return await self.item.callback(interaction)
async def interaction_check(self, interaction: Interaction[ClientT], /) -> bool:
return await self.item.interaction_check(interaction)

35
discord/ui/item.py

@ -40,7 +40,7 @@ if TYPE_CHECKING:
from .view import View
from ..components import Component
I = TypeVar('I', bound='Item')
I = TypeVar('I', bound='Item[Any]')
V = TypeVar('V', bound='View', covariant=True)
ItemCallbackType = Callable[[V, Interaction[Any], I], Coroutine[Any, Any, Any]]
@ -133,3 +133,36 @@ class Item(Generic[V]):
The interaction that triggered this UI item.
"""
pass
async def interaction_check(self, interaction: Interaction[ClientT], /) -> bool:
"""|coro|
A callback that is called when an interaction happens within this item
that checks whether the callback should be processed.
This is useful to override if, for example, you want to ensure that the
interaction author is a given user.
The default implementation of this returns ``True``.
.. note::
If an exception occurs within the body then the check
is considered a failure and :meth:`discord.ui.View.on_error` is called.
For :class:`~discord.ui.DynamicItem` this does not call the ``on_error``
handler.
.. versionadded:: 2.4
Parameters
-----------
interaction: :class:`~discord.Interaction`
The interaction that occurred.
Returns
---------
:class:`bool`
Whether the callback should be called.
"""
return True

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save